Skip to content

Commit 73f2209

Browse files
Add fetchmany to execute many *and* return rows (#1175)
Co-authored-by: Elvis Pranskevichus <elvis@edgedb.com>
1 parent b732b4f commit 73f2209

File tree

8 files changed

+153
-8
lines changed

8 files changed

+153
-8
lines changed

asyncpg/connection.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,44 @@ async def fetchrow(
756756
return None
757757
return data[0]
758758

759+
async def fetchmany(
760+
self, query, args, *, timeout: float=None, record_class=None
761+
):
762+
"""Run a query for each sequence of arguments in *args*
763+
and return the results as a list of :class:`Record`.
764+
765+
:param query:
766+
Query to execute.
767+
:param args:
768+
An iterable containing sequences of arguments for the query.
769+
:param float timeout:
770+
Optional timeout value in seconds.
771+
:param type record_class:
772+
If specified, the class to use for records returned by this method.
773+
Must be a subclass of :class:`~asyncpg.Record`. If not specified,
774+
a per-connection *record_class* is used.
775+
776+
:return list:
777+
A list of :class:`~asyncpg.Record` instances. If specified, the
778+
actual type of list elements would be *record_class*.
779+
780+
Example:
781+
782+
.. code-block:: pycon
783+
784+
>>> rows = await con.fetchmany('''
785+
... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
786+
... ''', [('x', 1), ('y', 2), ('z', 3)])
787+
>>> rows
788+
[<Record row=('x',)>, <Record row=('y',)>, <Record row=('z',)>]
789+
790+
.. versionadded:: 0.30.0
791+
"""
792+
self._check_open()
793+
return await self._executemany(
794+
query, args, timeout, return_rows=True, record_class=record_class
795+
)
796+
759797
async def copy_from_table(self, table_name, *, output,
760798
columns=None, schema_name=None, timeout=None,
761799
format=None, oids=None, delimiter=None,
@@ -1896,17 +1934,27 @@ async def __execute(
18961934
)
18971935
return result, stmt
18981936

1899-
async def _executemany(self, query, args, timeout):
1937+
async def _executemany(
1938+
self,
1939+
query,
1940+
args,
1941+
timeout,
1942+
return_rows=False,
1943+
record_class=None,
1944+
):
19001945
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
19011946
state=stmt,
19021947
args=args,
19031948
portal_name='',
19041949
timeout=timeout,
1950+
return_rows=return_rows,
19051951
)
19061952
timeout = self._protocol._get_timeout(timeout)
19071953
with self._stmt_exclusive_section:
19081954
with self._time_and_log(query, args, timeout):
1909-
result, _ = await self._do_execute(query, executor, timeout)
1955+
result, _ = await self._do_execute(
1956+
query, executor, timeout, record_class=record_class
1957+
)
19101958
return result
19111959

19121960
async def _do_execute(

asyncpg/pool.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None):
609609
record_class=record_class
610610
)
611611

612+
async def fetchmany(self, query, args, *, timeout=None, record_class=None):
613+
"""Run a query for each sequence of arguments in *args*
614+
and return the results as a list of :class:`Record`.
615+
616+
Pool performs this operation using one of its connections. Other than
617+
that, it behaves identically to
618+
:meth:`Connection.fetchmany()
619+
<asyncpg.connection.Connection.fetchmany>`.
620+
621+
.. versionadded:: 0.30.0
622+
"""
623+
async with self.acquire() as con:
624+
return await con.fetchmany(
625+
query, args, timeout=timeout, record_class=record_class
626+
)
627+
612628
async def copy_from_table(
613629
self,
614630
table_name,

asyncpg/prepared_stmt.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,27 @@ async def fetchrow(self, *args, timeout=None):
210210
return None
211211
return data[0]
212212

213+
@connresource.guarded
214+
async def fetchmany(self, args, *, timeout=None):
215+
"""Execute the statement and return a list of :class:`Record` objects.
216+
217+
:param args: Query arguments.
218+
:param float timeout: Optional timeout value in seconds.
219+
220+
:return: A list of :class:`Record` instances.
221+
222+
.. versionadded:: 0.30.0
223+
"""
224+
return await self.__do_execute(
225+
lambda protocol: protocol.bind_execute_many(
226+
self._state,
227+
args,
228+
portal_name='',
229+
timeout=timeout,
230+
return_rows=True,
231+
)
232+
)
233+
213234
@connresource.guarded
214235
async def executemany(self, args, *, timeout: float=None):
215236
"""Execute the statement for each sequence of arguments in *args*.
@@ -222,7 +243,12 @@ async def executemany(self, args, *, timeout: float=None):
222243
"""
223244
return await self.__do_execute(
224245
lambda protocol: protocol.bind_execute_many(
225-
self._state, args, '', timeout))
246+
self._state,
247+
args,
248+
portal_name='',
249+
timeout=timeout,
250+
return_rows=False,
251+
))
226252

227253
async def __do_execute(self, executor):
228254
protocol = self._connection._protocol

asyncpg/protocol/coreproto.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ cdef class CoreProtocol:
171171
cdef _bind_execute(self, str portal_name, str stmt_name,
172172
WriteBuffer bind_data, int32_t limit)
173173
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
174-
object bind_data)
174+
object bind_data, bint return_rows)
175175
cdef bint _bind_execute_many_more(self, bint first=*)
176176
cdef _bind_execute_many_fail(self, object error, bint first=*)
177177
cdef _bind(self, str portal_name, str stmt_name,

asyncpg/protocol/coreproto.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,12 +1020,12 @@ cdef class CoreProtocol:
10201020
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
10211021

10221022
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
1023-
object bind_data):
1023+
object bind_data, bint return_rows):
10241024
self._ensure_connected()
10251025
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
10261026

1027-
self.result = None
1028-
self._discard_data = True
1027+
self.result = [] if return_rows else None
1028+
self._discard_data = not return_rows
10291029
self._execute_iter = bind_data
10301030
self._execute_portal_name = portal_name
10311031
self._execute_stmt_name = stmt_name

asyncpg/protocol/protocol.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ cdef class BaseProtocol(CoreProtocol):
212212
args,
213213
portal_name: str,
214214
timeout,
215+
return_rows: bool,
215216
):
216217
if self.cancel_waiter is not None:
217218
await self.cancel_waiter
@@ -237,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol):
237238
more = self._bind_execute_many(
238239
portal_name,
239240
state.name,
240-
arg_bufs) # network op
241+
arg_bufs,
242+
return_rows) # network op
241243

242244
self.last_query = state.query
243245
self.statement = state

tests/test_execute.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,45 @@ async def test_executemany_basic(self):
139139
('a', 1), ('b', 2), ('c', 3), ('d', 4)
140140
])
141141

142+
async def test_executemany_returning(self):
143+
result = await self.con.fetchmany('''
144+
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
145+
''', [
146+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
147+
])
148+
self.assertEqual(result, [
149+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
150+
])
151+
result = await self.con.fetch('''
152+
SELECT * FROM exmany
153+
''')
154+
self.assertEqual(result, [
155+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
156+
])
157+
158+
# Empty set
159+
await self.con.fetchmany('''
160+
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
161+
''', ())
162+
result = await self.con.fetch('''
163+
SELECT * FROM exmany
164+
''')
165+
self.assertEqual(result, [
166+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
167+
])
168+
169+
# Without "RETURNING"
170+
result = await self.con.fetchmany('''
171+
INSERT INTO exmany VALUES($1, $2)
172+
''', [('e', 5), ('f', 6)])
173+
self.assertEqual(result, [])
174+
result = await self.con.fetch('''
175+
SELECT * FROM exmany
176+
''')
177+
self.assertEqual(result, [
178+
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
179+
])
180+
142181
async def test_executemany_bad_input(self):
143182
with self.assertRaisesRegex(
144183
exceptions.DataError,

tests/test_prepare.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,17 @@ async def test_prepare_explicitly_named(self):
611611
'prepared statement "foobar" already exists',
612612
):
613613
await self.con.prepare('select 1', name='foobar')
614+
615+
async def test_prepare_fetchmany(self):
616+
tr = self.con.transaction()
617+
await tr.start()
618+
try:
619+
await self.con.execute('CREATE TABLE fetchmany (a int, b text)')
620+
621+
stmt = await self.con.prepare(
622+
'INSERT INTO fetchmany (a, b) VALUES ($1, $2) RETURNING a, b'
623+
)
624+
result = await stmt.fetchmany([(1, 'a'), (2, 'b'), (3, 'c')])
625+
self.assertEqual(result, [(1, 'a'), (2, 'b'), (3, 'c')])
626+
finally:
627+
await tr.rollback()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy