Skip to content

Commit e41e1db

Browse files
committed
Ignore custom data codec for internal introspection
Fixes: #617
1 parent 68b40cb commit e41e1db

File tree

9 files changed

+83
-32
lines changed

9 files changed

+83
-32
lines changed

asyncpg/connection.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,16 @@ async def _get_statement(
342342
*,
343343
named: bool=False,
344344
use_cache: bool=True,
345+
ignore_custom_codec=False,
345346
record_class=None
346347
):
347348
if record_class is None:
348349
record_class = self._protocol.get_record_class()
349350

350351
if use_cache:
351-
statement = self._stmt_cache.get((query, record_class))
352+
statement = self._stmt_cache.get(
353+
(query, record_class, ignore_custom_codec)
354+
)
352355
if statement is not None:
353356
return statement
354357

@@ -371,6 +374,7 @@ async def _get_statement(
371374
query,
372375
timeout,
373376
record_class=record_class,
377+
ignore_custom_codec=ignore_custom_codec,
374378
)
375379
need_reprepare = False
376380
types_with_missing_codecs = statement._init_types()
@@ -415,7 +419,8 @@ async def _get_statement(
415419
)
416420

417421
if use_cache:
418-
self._stmt_cache.put((query, record_class), statement)
422+
self._stmt_cache.put(
423+
(query, record_class, ignore_custom_codec), statement)
419424

420425
# If we've just created a new statement object, check if there
421426
# are any statements for GC.
@@ -426,7 +431,12 @@ async def _get_statement(
426431

427432
async def _introspect_types(self, typeoids, timeout):
428433
return await self.__execute(
429-
self._intro_query, (list(typeoids),), 0, timeout)
434+
self._intro_query,
435+
(list(typeoids),),
436+
0,
437+
timeout,
438+
ignore_custom_codec=True,
439+
)
430440

431441
async def _introspect_type(self, typename, schema):
432442
if (
@@ -439,20 +449,22 @@ async def _introspect_type(self, typename, schema):
439449
[typeoid],
440450
limit=0,
441451
timeout=None,
452+
ignore_custom_codec=True,
442453
)
443-
if rows:
444-
typeinfo = rows[0]
445-
else:
446-
typeinfo = None
447454
else:
448-
typeinfo = await self.fetchrow(
449-
introspection.TYPE_BY_NAME, typename, schema)
455+
rows = await self._execute(
456+
introspection.TYPE_BY_NAME,
457+
[typename, schema],
458+
limit=1,
459+
timeout=None,
460+
ignore_custom_codec=True,
461+
)
450462

451-
if not typeinfo:
463+
if not rows:
452464
raise ValueError(
453465
'unknown type: {}.{}'.format(schema, typename))
454466

455-
return typeinfo
467+
return rows[0]
456468

457469
def cursor(
458470
self,
@@ -1325,7 +1337,9 @@ def _mark_stmts_as_closed(self):
13251337
def _maybe_gc_stmt(self, stmt):
13261338
if (
13271339
stmt.refs == 0
1328-
and not self._stmt_cache.has((stmt.query, stmt.record_class))
1340+
and not self._stmt_cache.has(
1341+
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
1342+
)
13291343
):
13301344
# If low-level `stmt` isn't referenced from any high-level
13311345
# `PreparedStatement` object and is not in the `_stmt_cache`:
@@ -1589,6 +1603,7 @@ async def _execute(
15891603
timeout,
15901604
*,
15911605
return_status=False,
1606+
ignore_custom_codec=False,
15921607
record_class=None
15931608
):
15941609
with self._stmt_exclusive_section:
@@ -1599,6 +1614,7 @@ async def _execute(
15991614
timeout,
16001615
return_status=return_status,
16011616
record_class=record_class,
1617+
ignore_custom_codec=ignore_custom_codec,
16021618
)
16031619
return result
16041620

@@ -1610,6 +1626,7 @@ async def __execute(
16101626
timeout,
16111627
*,
16121628
return_status=False,
1629+
ignore_custom_codec=False,
16131630
record_class=None
16141631
):
16151632
executor = lambda stmt, timeout: self._protocol.bind_execute(
@@ -1620,6 +1637,7 @@ async def __execute(
16201637
executor,
16211638
timeout,
16221639
record_class=record_class,
1640+
ignore_custom_codec=ignore_custom_codec,
16231641
)
16241642

16251643
async def _executemany(self, query, args, timeout):
@@ -1637,20 +1655,23 @@ async def _do_execute(
16371655
timeout,
16381656
retry=True,
16391657
*,
1658+
ignore_custom_codec=False,
16401659
record_class=None
16411660
):
16421661
if timeout is None:
16431662
stmt = await self._get_statement(
16441663
query,
16451664
None,
16461665
record_class=record_class,
1666+
ignore_custom_codec=ignore_custom_codec,
16471667
)
16481668
else:
16491669
before = time.monotonic()
16501670
stmt = await self._get_statement(
16511671
query,
16521672
timeout,
16531673
record_class=record_class,
1674+
ignore_custom_codec=ignore_custom_codec,
16541675
)
16551676
after = time.monotonic()
16561677
timeout -= after - before

asyncpg/protocol/codecs/base.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,6 @@ cdef class DataCodecConfig:
166166
dict _derived_type_codecs
167167
dict _custom_type_codecs
168168

169-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
169+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170+
bint ignore_custom_codec=*)
170171
cdef inline Codec get_any_local_codec(self, uint32_t oid)

asyncpg/protocol/codecs/base.pyx

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,20 @@ cdef class DataCodecConfig:
692692

693693
return codec
694694

695-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
695+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
696+
bint ignore_custom_codec=False):
696697
cdef Codec codec
697698

698-
codec = self.get_any_local_codec(oid)
699-
if codec is not None:
700-
if codec.format != format:
701-
# The codec for this OID has been overridden by
702-
# set_{builtin}_type_codec with a different format.
703-
# We must respect that and not return a core codec.
704-
return None
705-
else:
706-
return codec
699+
if not ignore_custom_codec:
700+
codec = self.get_any_local_codec(oid)
701+
if codec is not None:
702+
if codec.format != format:
703+
# The codec for this OID has been overridden by
704+
# set_{builtin}_type_codec with a different format.
705+
# We must respect that and not return a core codec.
706+
return None
707+
else:
708+
return codec
707709

708710
codec = get_core_codec(oid, format)
709711
if codec is not None:

asyncpg/protocol/prepared_stmt.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cdef class PreparedStatementState:
1212
readonly bint closed
1313
readonly int refs
1414
readonly type record_class
15+
readonly bint ignore_custom_codec
1516

1617

1718
list row_desc

asyncpg/protocol/prepared_stmt.pyx

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ cdef class PreparedStatementState:
1616
str name,
1717
str query,
1818
BaseProtocol protocol,
19-
type record_class
19+
type record_class,
20+
bint ignore_custom_codec
2021
):
2122
self.name = name
2223
self.query = query
@@ -28,6 +29,7 @@ cdef class PreparedStatementState:
2829
self.closed = False
2930
self.refs = 0
3031
self.record_class = record_class
32+
self.ignore_custom_codec = ignore_custom_codec
3133

3234
def _get_parameters(self):
3335
cdef Codec codec
@@ -87,6 +89,7 @@ cdef class PreparedStatementState:
8789
return missing
8890

8991
cpdef _init_codecs(self):
92+
9093
self._ensure_args_encoder()
9194
self._ensure_rows_decoder()
9295

@@ -205,7 +208,8 @@ cdef class PreparedStatementState:
205208
cols_mapping[col_name] = i
206209
cols_names.append(col_name)
207210
oid = row[3]
208-
codec = self.settings.get_data_codec(oid)
211+
codec = self.settings.get_data_codec(
212+
oid, ignore_custom_codec=self.ignore_custom_codec)
209213
if codec is None or not codec.has_decoder():
210214
raise exceptions.InternalClientError(
211215
'no decoder for OID {}'.format(oid))
@@ -230,7 +234,8 @@ cdef class PreparedStatementState:
230234

231235
for i from 0 <= i < self.args_num:
232236
p_oid = self.parameters_desc[i]
233-
codec = self.settings.get_data_codec(p_oid)
237+
codec = self.settings.get_data_codec(
238+
p_oid, ignore_custom_codec=self.ignore_custom_codec)
234239
if codec is None or not codec.has_encoder():
235240
raise exceptions.InternalClientError(
236241
'no encoder for OID {}'.format(p_oid))

asyncpg/protocol/protocol.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ cdef class BaseProtocol(CoreProtocol):
145145
async def prepare(self, stmt_name, query, timeout,
146146
*,
147147
PreparedStatementState state=None,
148+
ignore_custom_codec=False,
148149
record_class):
149150
if self.cancel_waiter is not None:
150151
await self.cancel_waiter
@@ -161,7 +162,7 @@ cdef class BaseProtocol(CoreProtocol):
161162
self.last_query = query
162163
if state is None:
163164
state = PreparedStatementState(
164-
stmt_name, query, self, record_class)
165+
stmt_name, query, self, record_class, ignore_custom_codec)
165166
self.statement = state
166167
except Exception as ex:
167168
waiter.set_exception(ex)

asyncpg/protocol/settings.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
2626
cpdef inline set_builtin_type_codec(
2727
self, typeoid, typename, typeschema, typekind, alias_to, format)
2828
cpdef inline Codec get_data_codec(
29-
self, uint32_t oid, ServerDataFormat format=*)
29+
self, uint32_t oid, ServerDataFormat format=*,
30+
bint ignore_custom_codec=*)

asyncpg/protocol/settings.pyx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8787
typekind, alias_to, _format)
8888

8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
90-
ServerDataFormat format=PG_FORMAT_ANY):
90+
ServerDataFormat format=PG_FORMAT_ANY,
91+
bint ignore_custom_codec=False):
9192
if format == PG_FORMAT_ANY:
92-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
93+
codec = self._data_codecs.get_codec(
94+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
9395
if codec is None:
94-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
96+
codec = self._data_codecs.get_codec(
97+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
9598
return codec
9699
else:
97-
return self._data_codecs.get_codec(oid, format)
100+
return self._data_codecs.get_codec(
101+
oid, format, ignore_custom_codec)
98102

99103
def __getattr__(self, name):
100104
if not name.startswith('_'):

tests/test_introspection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,20 @@ def tearDownClass(cls):
4343

4444
super().tearDownClass()
4545

46+
def setUp(self):
47+
super().setUp()
48+
self.loop.run_until_complete(self._add_custom_codec(self.con))
49+
50+
async def _add_custom_codec(self, conn):
51+
# mess up with the codec - builtin introspection shouldn't be affected
52+
await conn.set_type_codec(
53+
"oid",
54+
schema="pg_catalog",
55+
encoder=lambda value: None,
56+
decoder=lambda value: None,
57+
format="text",
58+
)
59+
4660
@tb.with_connection_options(database='asyncpg_intro_test')
4761
async def test_introspection_on_large_db(self):
4862
await self.con.execute(
@@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
142156
# query would cause introspection to retry.
143157
slow_intro_conn = await self.connect(
144158
connection_class=SlowIntrospectionConnection)
159+
await self._add_custom_codec(slow_intro_conn)
145160
try:
146161
await self.con.execute('''
147162
CREATE DOMAIN intro_1_t AS int;

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy