tokio_postgres/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::{ready, Stream};
10use log::{debug, log_enabled, Level};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::marker::PhantomPinned;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
22
23impl<T> fmt::Debug for BorrowToSqlParamsDebug<'_, T>
24where
25    T: BorrowToSql,
26{
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_list()
29            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
30            .finish()
31    }
32}
33
34pub async fn query<P, I>(
35    client: &InnerClient,
36    statement: Statement,
37    params: I,
38) -> Result<RowStream, Error>
39where
40    P: BorrowToSql,
41    I: IntoIterator<Item = P>,
42    I::IntoIter: ExactSizeIterator,
43{
44    let buf = if log_enabled!(Level::Debug) {
45        let params = params.into_iter().collect::<Vec<_>>();
46        debug!(
47            "executing statement {} with parameters: {:?}",
48            statement.name(),
49            BorrowToSqlParamsDebug(params.as_slice()),
50        );
51        encode(client, &statement, params)?
52    } else {
53        encode(client, &statement, params)?
54    };
55    let responses = start(client, buf).await?;
56    Ok(RowStream {
57        statement,
58        responses,
59        rows_affected: None,
60        _p: PhantomPinned,
61    })
62}
63
64pub async fn query_typed<P, I>(
65    client: &Arc<InnerClient>,
66    query: &str,
67    params: I,
68) -> Result<RowStream, Error>
69where
70    P: BorrowToSql,
71    I: IntoIterator<Item = (P, Type)>,
72{
73    let buf = {
74        let params = params.into_iter().collect::<Vec<_>>();
75        let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
76
77        client.with_buf(|buf| {
78            frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
79            encode_bind_raw("", params, "", buf)?;
80            frontend::describe(b'S', "", buf).map_err(Error::encode)?;
81            frontend::execute("", 0, buf).map_err(Error::encode)?;
82            frontend::sync(buf);
83
84            Ok(buf.split().freeze())
85        })?
86    };
87
88    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
89
90    loop {
91        match responses.next().await? {
92            Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
93            Message::NoData => {
94                return Ok(RowStream {
95                    statement: Statement::unnamed(vec![], vec![]),
96                    responses,
97                    rows_affected: None,
98                    _p: PhantomPinned,
99                });
100            }
101            Message::RowDescription(row_description) => {
102                let mut columns: Vec<Column> = vec![];
103                let mut it = row_description.fields();
104                while let Some(field) = it.next().map_err(Error::parse)? {
105                    let type_ = get_type(client, field.type_oid()).await?;
106                    let column = Column {
107                        name: field.name().to_string(),
108                        table_oid: Some(field.table_oid()).filter(|n| *n != 0),
109                        column_id: Some(field.column_id()).filter(|n| *n != 0),
110                        r#type: type_,
111                    };
112                    columns.push(column);
113                }
114                return Ok(RowStream {
115                    statement: Statement::unnamed(vec![], columns),
116                    responses,
117                    rows_affected: None,
118                    _p: PhantomPinned,
119                });
120            }
121            _ => return Err(Error::unexpected_message()),
122        }
123    }
124}
125
126pub async fn query_portal(
127    client: &InnerClient,
128    portal: &Portal,
129    max_rows: i32,
130) -> Result<RowStream, Error> {
131    let buf = client.with_buf(|buf| {
132        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
133        frontend::sync(buf);
134        Ok(buf.split().freeze())
135    })?;
136
137    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
138
139    Ok(RowStream {
140        statement: portal.statement().clone(),
141        responses,
142        rows_affected: None,
143        _p: PhantomPinned,
144    })
145}
146
147/// Extract the number of rows affected from [`CommandCompleteBody`].
148pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
149    let rows = body
150        .tag()
151        .map_err(Error::parse)?
152        .rsplit(' ')
153        .next()
154        .unwrap()
155        .parse()
156        .unwrap_or(0);
157    Ok(rows)
158}
159
160pub async fn execute<P, I>(
161    client: &InnerClient,
162    statement: Statement,
163    params: I,
164) -> Result<u64, Error>
165where
166    P: BorrowToSql,
167    I: IntoIterator<Item = P>,
168    I::IntoIter: ExactSizeIterator,
169{
170    let buf = if log_enabled!(Level::Debug) {
171        let params = params.into_iter().collect::<Vec<_>>();
172        debug!(
173            "executing statement {} with parameters: {:?}",
174            statement.name(),
175            BorrowToSqlParamsDebug(params.as_slice()),
176        );
177        encode(client, &statement, params)?
178    } else {
179        encode(client, &statement, params)?
180    };
181    let mut responses = start(client, buf).await?;
182
183    let mut rows = 0;
184    loop {
185        match responses.next().await? {
186            Message::DataRow(_) => {}
187            Message::CommandComplete(body) => {
188                rows = extract_row_affected(&body)?;
189            }
190            Message::EmptyQueryResponse => rows = 0,
191            Message::ReadyForQuery(_) => return Ok(rows),
192            _ => return Err(Error::unexpected_message()),
193        }
194    }
195}
196
197async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
198    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
199
200    match responses.next().await? {
201        Message::BindComplete => {}
202        _ => return Err(Error::unexpected_message()),
203    }
204
205    Ok(responses)
206}
207
208pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
209where
210    P: BorrowToSql,
211    I: IntoIterator<Item = P>,
212    I::IntoIter: ExactSizeIterator,
213{
214    client.with_buf(|buf| {
215        encode_bind(statement, params, "", buf)?;
216        frontend::execute("", 0, buf).map_err(Error::encode)?;
217        frontend::sync(buf);
218        Ok(buf.split().freeze())
219    })
220}
221
222pub fn encode_bind<P, I>(
223    statement: &Statement,
224    params: I,
225    portal: &str,
226    buf: &mut BytesMut,
227) -> Result<(), Error>
228where
229    P: BorrowToSql,
230    I: IntoIterator<Item = P>,
231    I::IntoIter: ExactSizeIterator,
232{
233    let params = params.into_iter();
234    if params.len() != statement.params().len() {
235        return Err(Error::parameters(params.len(), statement.params().len()));
236    }
237
238    encode_bind_raw(
239        statement.name(),
240        params.zip(statement.params().iter().cloned()),
241        portal,
242        buf,
243    )
244}
245
246fn encode_bind_raw<P, I>(
247    statement_name: &str,
248    params: I,
249    portal: &str,
250    buf: &mut BytesMut,
251) -> Result<(), Error>
252where
253    P: BorrowToSql,
254    I: IntoIterator<Item = (P, Type)>,
255    I::IntoIter: ExactSizeIterator,
256{
257    let (param_formats, params): (Vec<_>, Vec<_>) = params
258        .into_iter()
259        .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
260        .unzip();
261
262    let mut error_idx = 0;
263    let r = frontend::bind(
264        portal,
265        statement_name,
266        param_formats,
267        params.into_iter().enumerate(),
268        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
269            Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
270            Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
271            Err(e) => {
272                error_idx = idx;
273                Err(e)
274            }
275        },
276        Some(1),
277        buf,
278    );
279    match r {
280        Ok(()) => Ok(()),
281        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
282        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
283    }
284}
285
286pin_project! {
287    /// A stream of table rows.
288    pub struct RowStream {
289        statement: Statement,
290        responses: Responses,
291        rows_affected: Option<u64>,
292        #[pin]
293        _p: PhantomPinned,
294    }
295}
296
297impl Stream for RowStream {
298    type Item = Result<Row, Error>;
299
300    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301        let this = self.project();
302        loop {
303            match ready!(this.responses.poll_next(cx)?) {
304                Message::DataRow(body) => {
305                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
306                }
307                Message::CommandComplete(body) => {
308                    *this.rows_affected = Some(extract_row_affected(&body)?);
309                }
310                Message::EmptyQueryResponse | Message::PortalSuspended => {}
311                Message::ReadyForQuery(_) => return Poll::Ready(None),
312                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
313            }
314        }
315    }
316}
317
318impl RowStream {
319    /// Returns the number of rows affected by the query.
320    ///
321    /// This function will return `None` until the stream has been exhausted.
322    pub fn rows_affected(&self) -> Option<u64> {
323        self.rows_affected
324    }
325}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy