tokio_postgres/connection.rs

use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::copy_in::CopyInReceiver;
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_channel::mpsc;
use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt};
use log::{info, trace};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;

20pub enum RequestMessages {
21    Single(FrontendMessage),
22    CopyIn(CopyInReceiver),
23}
24
25pub struct Request {
26    pub messages: RequestMessages,
27    pub sender: mpsc::Sender<BackendMessages>,
28}
29
30pub struct Response {
31    sender: mpsc::Sender<BackendMessages>,
32}
33
/// Where the connection is in its shutdown sequence.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Normal operation: responses are read and requests are written.
    Active,
    /// The request channel has closed with no responses outstanding; a `Terminate` message is
    /// being queued for the server.
    Terminating,
    /// `Terminate` has been handed to the sink; only flushing and closing the stream remain.
    Closing,
}

41/// A connection to a PostgreSQL database.
42///
43/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
44/// server, and should generally be spawned off onto an executor to run in the background.
45///
46/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
47/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
48#[must_use = "futures do nothing unless polled"]
49pub struct Connection<S, T> {
50    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
51    parameters: HashMap<String, String>,
52    receiver: mpsc::UnboundedReceiver<Request>,
53    pending_request: Option<RequestMessages>,
54    pending_responses: VecDeque<BackendMessage>,
55    responses: VecDeque<Response>,
56    state: State,
57}
58
59impl<S, T> Connection<S, T>
60where
61    S: AsyncRead + AsyncWrite + Unpin,
62    T: AsyncRead + AsyncWrite + Unpin,
63{
64    pub(crate) fn new(
65        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
66        pending_responses: VecDeque<BackendMessage>,
67        parameters: HashMap<String, String>,
68        receiver: mpsc::UnboundedReceiver<Request>,
69    ) -> Connection<S, T> {
70        Connection {
71            stream,
72            parameters,
73            receiver,
74            pending_request: None,
75            pending_responses,
76            responses: VecDeque::new(),
77            state: State::Active,
78        }
79    }
80
81    fn poll_response(
82        &mut self,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<BackendMessage, Error>>> {
85        if let Some(message) = self.pending_responses.pop_front() {
86            trace!("retrying pending response");
87            return Poll::Ready(Some(Ok(message)));
88        }
89
90        Pin::new(&mut self.stream)
91            .poll_next(cx)
92            .map(|o| o.map(|r| r.map_err(Error::io)))
93    }
94
95    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
96        if self.state != State::Active {
97            trace!("poll_read: done");
98            return Ok(None);
99        }
100
101        loop {
102            let message = match self.poll_response(cx)? {
103                Poll::Ready(Some(message)) => message,
104                Poll::Ready(None) => return Err(Error::closed()),
105                Poll::Pending => {
106                    trace!("poll_read: waiting on response");
107                    return Ok(None);
108                }
109            };
110
111            let (mut messages, request_complete) = match message {
112                BackendMessage::Async(Message::NoticeResponse(body)) => {
113                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
114                    return Ok(Some(AsyncMessage::Notice(error)));
115                }
116                BackendMessage::Async(Message::NotificationResponse(body)) => {
117                    let notification = Notification {
118                        process_id: body.process_id(),
119                        channel: body.channel().map_err(Error::parse)?.to_string(),
120                        payload: body.message().map_err(Error::parse)?.to_string(),
121                    };
122                    return Ok(Some(AsyncMessage::Notification(notification)));
123                }
124                BackendMessage::Async(Message::ParameterStatus(body)) => {
125                    self.parameters.insert(
126                        body.name().map_err(Error::parse)?.to_string(),
127                        body.value().map_err(Error::parse)?.to_string(),
128                    );
129                    continue;
130                }
131                BackendMessage::Async(_) => unreachable!(),
132                BackendMessage::Normal {
133                    messages,
134                    request_complete,
135                } => (messages, request_complete),
136            };
137
138            let mut response = match self.responses.pop_front() {
139                Some(response) => response,
140                None => match messages.next().map_err(Error::parse)? {
141                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
142                    _ => return Err(Error::unexpected_message()),
143                },
144            };
145
146            match response.sender.poll_ready(cx) {
147                Poll::Ready(Ok(())) => {
148                    let _ = response.sender.start_send(messages);
149                    if !request_complete {
150                        self.responses.push_front(response);
151                    }
152                }
153                Poll::Ready(Err(_)) => {
154                    // we need to keep paging through the rest of the messages even if the receiver's hung up
155                    if !request_complete {
156                        self.responses.push_front(response);
157                    }
158                }
159                Poll::Pending => {
160                    self.responses.push_front(response);
161                    self.pending_responses.push_back(BackendMessage::Normal {
162                        messages,
163                        request_complete,
164                    });
165                    trace!("poll_read: waiting on sender");
166                    return Ok(None);
167                }
168            }
169        }
170    }
171
172    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
173        if let Some(messages) = self.pending_request.take() {
174            trace!("retrying pending request");
175            return Poll::Ready(Some(messages));
176        }
177
178        if self.receiver.is_terminated() {
179            return Poll::Ready(None);
180        }
181
182        match self.receiver.poll_next_unpin(cx) {
183            Poll::Ready(Some(request)) => {
184                trace!("polled new request");
185                self.responses.push_back(Response {
186                    sender: request.sender,
187                });
188                Poll::Ready(Some(request.messages))
189            }
190            Poll::Ready(None) => Poll::Ready(None),
191            Poll::Pending => Poll::Pending,
192        }
193    }
194
195    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
196        loop {
197            if self.state == State::Closing {
198                trace!("poll_write: done");
199                return Ok(false);
200            }
201
202            if Pin::new(&mut self.stream)
203                .poll_ready(cx)
204                .map_err(Error::io)?
205                .is_pending()
206            {
207                trace!("poll_write: waiting on socket");
208                return Ok(false);
209            }
210
211            let request = match self.poll_request(cx) {
212                Poll::Ready(Some(request)) => request,
213                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
214                    trace!("poll_write: at eof, terminating");
215                    self.state = State::Terminating;
216                    let mut request = BytesMut::new();
217                    frontend::terminate(&mut request);
218                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
219                }
220                Poll::Ready(None) => {
221                    trace!(
222                        "poll_write: at eof, pending responses {}",
223                        self.responses.len()
224                    );
225                    return Ok(true);
226                }
227                Poll::Pending => {
228                    trace!("poll_write: waiting on request");
229                    return Ok(true);
230                }
231            };
232
233            match request {
234                RequestMessages::Single(request) => {
235                    Pin::new(&mut self.stream)
236                        .start_send(request)
237                        .map_err(Error::io)?;
238                    if self.state == State::Terminating {
239                        trace!("poll_write: sent eof, closing");
240                        self.state = State::Closing;
241                    }
242                }
243                RequestMessages::CopyIn(mut receiver) => {
244                    let message = match receiver.poll_next_unpin(cx) {
245                        Poll::Ready(Some(message)) => message,
246                        Poll::Ready(None) => {
247                            trace!("poll_write: finished copy_in request");
248                            continue;
249                        }
250                        Poll::Pending => {
251                            trace!("poll_write: waiting on copy_in stream");
252                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
253                            return Ok(true);
254                        }
255                    };
256                    Pin::new(&mut self.stream)
257                        .start_send(message)
258                        .map_err(Error::io)?;
259                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
260                }
261            }
262        }
263    }
264
265    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
266        match Pin::new(&mut self.stream)
267            .poll_flush(cx)
268            .map_err(Error::io)?
269        {
270            Poll::Ready(()) => trace!("poll_flush: flushed"),
271            Poll::Pending => trace!("poll_flush: waiting on socket"),
272        }
273        Ok(())
274    }
275
276    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
277        if self.state != State::Closing {
278            return Poll::Pending;
279        }
280
281        match Pin::new(&mut self.stream)
282            .poll_close(cx)
283            .map_err(Error::io)?
284        {
285            Poll::Ready(()) => {
286                trace!("poll_shutdown: complete");
287                Poll::Ready(Ok(()))
288            }
289            Poll::Pending => {
290                trace!("poll_shutdown: waiting on socket");
291                Poll::Pending
292            }
293        }
294    }
295
296    /// Returns the value of a runtime parameter for this connection.
297    pub fn parameter(&self, name: &str) -> Option<&str> {
298        self.parameters.get(name).map(|s| &**s)
299    }
300
301    /// Polls for asynchronous messages from the server.
302    ///
303    /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
304    /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
305    ///
306    /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after
307    /// receiving one of those values.
308    pub fn poll_message(
309        &mut self,
310        cx: &mut Context<'_>,
311    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
312        let message = self.poll_read(cx)?;
313        let want_flush = self.poll_write(cx)?;
314        if want_flush {
315            self.poll_flush(cx)?;
316        }
317        match message {
318            Some(message) => Poll::Ready(Some(Ok(message))),
319            None => match self.poll_shutdown(cx) {
320                Poll::Ready(Ok(())) => Poll::Ready(None),
321                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
322                Poll::Pending => Poll::Pending,
323            },
324        }
325    }
326}
327
328impl<S, T> Future for Connection<S, T>
329where
330    S: AsyncRead + AsyncWrite + Unpin,
331    T: AsyncRead + AsyncWrite + Unpin,
332{
333    type Output = Result<(), Error>;
334
335    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
336        while let Some(message) = ready!(self.poll_message(cx)?) {
337            if let AsyncMessage::Notice(notice) = message {
338                info!("{}: {}", notice.severity(), notice.message());
339            }
340        }
341        Poll::Ready(Ok(()))
342    }
343}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy