Skip to content

Commit 494ee68

Browse files
committed
➕ bulk_update, tests and raw_sql
* The bulk_update is an adaptation of encode/orm#148
1 parent 640fb03 commit 494ee68

File tree

7 files changed

+308
-22
lines changed

7 files changed

+308
-22
lines changed

saffier/core/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import typing
2+
from enum import Enum
23
from inspect import isclass
34

5+
from orjson import OPT_OMIT_MICROSECONDS # noqa
6+
from orjson import OPT_SERIALIZE_NUMPY # noqa
7+
from orjson import dumps
48
from typing_extensions import get_origin
59

610
from saffier.fields import DateField, DateTimeField
@@ -21,6 +25,16 @@ def _update_auto_now_fields(self, values: DictAny, fields: DictAny) -> DictAny:
2125
values[k] = v.validator.get_default_value()
2226
return values
2327

28+
def _resolve_value(self, value: typing.Any):
29+
if isinstance(value, dict):
30+
return dumps(
31+
value,
32+
option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS,
33+
).decode("utf-8")
34+
elif isinstance(value, Enum):
35+
return value.name
36+
return value
37+
2438

2539
def is_class_and_subclass(value: typing.Any, _type: typing.Any) -> bool:
2640
original = get_origin(value)

saffier/db/fields.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def validate(self, value: typing.Any) -> typing.Any:
145145
elif not isinstance(value, str):
146146
raise self.validation_error("type")
147147

148-
# The null character is always invalid.
149148
value = value.replace("\0", "")
150149

151150
if self.trim_whitespace:

saffier/db/queryset.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from saffier.types import DictAny
1313

1414
if typing.TYPE_CHECKING: # pragma: no cover
15+
from saffier.db.connection import Database
1516
from saffier.models import Model
1617

1718

@@ -25,7 +26,7 @@ class QuerySetProps:
2526
"""
2627

2728
@property
28-
def database(self):
29+
def database(self) -> "Database":
2930
return self.model_class._meta.registry.database
3031

3132
@property
@@ -117,6 +118,7 @@ def _build_select(self):
117118
if self.distinct_on:
118119
expression = self._build_select_distinct(self.distinct_on, expression=expression)
119120

121+
setattr(self, "_expression", expression)
120122
return expression
121123

122124
def _filter_query(self, exclude: bool = False, **kwargs):
@@ -232,6 +234,7 @@ def _clone(self) -> "QuerySet[SaffierModel]":
232234
queryset._order_by = copy.copy(self._order_by)
233235
queryset._group_by = copy.copy(self._group_by)
234236
queryset.distinct_on = copy.copy(self.distinct_on)
237+
queryset._expression = self._expression
235238
return queryset
236239

237240

@@ -262,14 +265,30 @@ def __init__(
262265
self._order_by = [] if order_by is None else order_by
263266
self._group_by = [] if group_by is None else group_by
264267
self.distinct_on = [] if distinct_on is None else distinct_on
268+
self._expression = None
265269

266270
def __get__(self, instance, owner):
267271
return self.__class__(model_class=owner)
268272

273+
@property
274+
def sql(self):
275+
return str(self._expression)
276+
277+
@sql.setter
278+
def sql(self, value):
279+
setattr(self, "_expression", value)
280+
269281
async def __aiter__(self) -> typing.AsyncIterator[SaffierModel]:
270282
for value in await self:
271283
yield value
272284

285+
def _set_query_expression(self, expression: typing.Any) -> None:
286+
"""
287+
Sets the value of the sql property to the expression used.
288+
"""
289+
self.sql = expression
290+
self.model_class.raw_query = self.sql
291+
273292
def _filter_or_exclude(
274293
self,
275294
clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None,
@@ -389,6 +408,7 @@ async def exists(self) -> bool:
389408
"""
390409
expression = self._build_select()
391410
expression = sqlalchemy.exists(expression).select()
411+
self._set_query_expression(expression)
392412
return await self.database.fetch_val(expression)
393413

394414
async def count(self) -> int:
@@ -397,6 +417,7 @@ async def count(self) -> int:
397417
"""
398418
expression = self._build_select().alias("subquery_for_count")
399419
expression = sqlalchemy.func.count().select().select_from(expression)
420+
self._set_query_expression(expression)
400421
return await self.database.fetch_val(expression)
401422

402423
async def get_or_none(self, **kwargs):
@@ -405,6 +426,7 @@ async def get_or_none(self, **kwargs):
405426
"""
406427
queryset = self.filter(**kwargs)
407428
expression = queryset._build_select().limit(2)
429+
self._set_query_expression(expression)
408430
rows = await self.database.fetch_all(expression)
409431

410432
if not rows:
@@ -422,7 +444,12 @@ async def all(self, **kwargs):
422444
return await queryset.filter(**kwargs).all()
423445

424446
expression = queryset._build_select()
447+
self._set_query_expression(expression)
448+
425449
rows = await queryset.database.fetch_all(expression)
450+
451+
# Attach the raw query to the object
452+
queryset.model_class.raw_query = self.sql
426453
return [
427454
queryset.model_class._from_row(row, select_related=self._select_related)
428455
for row in rows
@@ -437,6 +464,7 @@ async def get(self, **kwargs):
437464

438465
expression = self._build_select().limit(2)
439466
rows = await self.database.fetch_all(expression)
467+
self._set_query_expression(expression)
440468

441469
if not rows:
442470
raise DoesNotFound()
@@ -475,6 +503,7 @@ async def create(self, **kwargs):
475503
kwargs = self._validate_kwargs(**kwargs)
476504
instance = self.model_class(**kwargs)
477505
expression = self.table.insert().values(**kwargs)
506+
self._set_query_expression(expression)
478507

479508
if self.pkname not in kwargs:
480509
instance.pk = await self.database.execute(expression)
@@ -490,13 +519,57 @@ async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
490519
new_objs = [self._validate_kwargs(**obj) for obj in objs]
491520

492521
expression = self.table.insert().values(new_objs)
522+
self._set_query_expression(expression)
493523
await self.database.execute(expression)
494524

525+
async def bulk_update(self, objs: typing.List[SaffierModel], fields: typing.List[str]) -> None:
526+
"""
527+
Bulk updates records in a table.
528+
529+
A similar solution was suggested here: https://github.com/encode/orm/pull/148
530+
531+
It is thought to be a clean approach to a simple problem so it was added here and
532+
refactored to be compatible with Saffier.
533+
"""
534+
new_fields = {}
535+
for key, field in self.model_class.fields.items():
536+
if key in fields:
537+
new_fields[key] = field.validator
538+
539+
validator = Schema(fields=new_fields)
540+
541+
new_objs = []
542+
for obj in objs:
543+
new_obj = {}
544+
for key, value in obj.__dict__.items():
545+
if key in fields:
546+
new_obj[key] = self._resolve_value(value)
547+
new_objs.append(new_obj)
548+
549+
new_objs = [
550+
self._update_auto_now_fields(validator.validate(obj), self.model_class.fields)
551+
for obj in new_objs
552+
]
553+
554+
pk = getattr(self.table.c, self.pkname)
555+
expression = self.table.update().where(pk == sqlalchemy.bindparam(self.pkname))
556+
kwargs = {field: sqlalchemy.bindparam(field) for obj in new_objs for field in obj.keys()}
557+
pks = [{self.pkname: getattr(obj, self.pkname)} for obj in objs]
558+
559+
query_list = []
560+
for pk, value in zip(pks, new_objs):
561+
query_list.append({**pk, **value})
562+
563+
expression = expression.values(kwargs)
564+
self._set_query_expression(expression)
565+
await self.database.execute_many(str(expression), query_list)
566+
495567
async def delete(self) -> None:
496568
expression = self.table.delete()
497569
for filter_clause in self.filter_clauses:
498570
expression = expression.where(filter_clause)
499571

572+
self._set_query_expression(expression)
500573
await self.database.execute(expression)
501574

502575
async def update(self, **kwargs) -> None:
@@ -509,12 +582,13 @@ async def update(self, **kwargs) -> None:
509582

510583
validator = Schema(fields=fields)
511584
kwargs = self._update_auto_now_fields(validator.validate(kwargs), self.model_class.fields)
512-
expr = self.table.update().values(**kwargs)
585+
expression = self.table.update().values(**kwargs)
513586

514587
for filter_clause in self.filter_clauses:
515-
expr = expr.where(filter_clause)
588+
expression = expression.where(filter_clause)
516589

517-
await self.database.execute(expr)
590+
self._set_query_expression(expression)
591+
await self.database.execute(expression)
518592

519593
async def get_or_create(
520594
self, defaults: typing.Dict[str, typing.Any], **kwargs

saffier/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Model(ModelMeta, ModelUtil):
1818
query = Manager()
1919
_meta = MetaInfo(None)
2020
_db_model: bool = False
21+
_raw_query: str = None
2122

2223
def __init__(self, **kwargs: DictAny) -> None:
2324
if "pk" in kwargs:
@@ -55,6 +56,14 @@ def pk(self):
5556
def pk(self, value):
5657
setattr(self, self.pkname, value)
5758

59+
@property
60+
def raw_query(self):
61+
return getattr(self, self._raw_query)
62+
63+
@raw_query.setter
64+
def raw_query(self, value):
65+
setattr(self, self.raw_query, value)
66+
5867
def __repr__(self):
5968
return f"<{self.__class__.__name__}: {self}>"
6069

tests/models/test_bulk_create.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import datetime
2+
from enum import Enum
3+
4+
import pytest
5+
from tests.settings import DATABASE_URL
6+
7+
import saffier
8+
from saffier import fields
9+
from saffier.db.connection import Database
10+
11+
pytestmark = pytest.mark.anyio
12+
13+
database = Database(DATABASE_URL)
14+
models = saffier.Registry(database=database)
15+
16+
17+
def time():
18+
return datetime.datetime.now().time()
19+
20+
21+
class StatusEnum(Enum):
22+
DRAFT = "Draft"
23+
RELEASED = "Released"
24+
25+
26+
class Product(saffier.Model):
27+
id = fields.IntegerField(primary_key=True)
28+
uuid = fields.UUIDField(null=True)
29+
created = fields.DateTimeField(default=datetime.datetime.now)
30+
created_day = fields.DateField(default=datetime.date.today)
31+
created_time = fields.TimeField(default=time)
32+
created_date = fields.DateField(auto_now_add=True)
33+
created_datetime = fields.DateTimeField(auto_now_add=True)
34+
updated_datetime = fields.DateTimeField(auto_now=True)
35+
updated_date = fields.DateField(auto_now=True)
36+
data = fields.JSONField(default={})
37+
description = fields.CharField(blank=True, max_length=255)
38+
huge_number = fields.BigIntegerField(default=0)
39+
price = fields.DecimalField(max_digits=5, decimal_places=2, null=True)
40+
status = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT)
41+
value = fields.FloatField(null=True)
42+
43+
class Meta:
44+
registry = models
45+
46+
47+
@pytest.fixture(autouse=True, scope="module")
48+
async def create_test_database():
49+
await models.create_all()
50+
yield
51+
await models.drop_all()
52+
53+
54+
@pytest.fixture(autouse=True)
55+
async def rollback_transactions():
56+
with database.force_rollback():
57+
async with database:
58+
yield
59+
60+
61+
async def test_bulk_create():
62+
await Product.query.bulk_create(
63+
[
64+
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
65+
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
66+
]
67+
)
68+
products = await Product.query.all()
69+
assert len(products) == 2
70+
assert products[0].data == {"foo": 123}
71+
assert products[0].value == 123.456
72+
assert products[0].status == StatusEnum.RELEASED
73+
assert products[1].data == {"foo": 456}
74+
assert products[1].value == 456.789
75+
assert products[1].status == StatusEnum.DRAFT

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy