diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml
index 96e14cfb..52ee9919 100644
--- a/.github/workflows/formatter.yml
+++ b/.github/workflows/formatter.yml
@@ -7,20 +7,16 @@ on:
 jobs:
   linter_name:
-    name: runner / black
+    name: runner / ruff
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - name: Check files using the black formatter
-        uses: rickstaa/action-black@v1
-        id: black_formatter
+      - uses: actions/checkout@v3
+      - name: Check files using the ruff formatter
+        uses: chartboost/ruff-action@v1
+        id: ruff_formatter
         with:
-          black_args: ". -l 120"
-      - name: Annotate diff changes using reviewdog
-        if: steps.black_formatter.outputs.is_formatted == 'true'
-        uses: reviewdog/action-suggester@v1
+          args: format
+      - name: Auto commit ruff formatting
+        uses: stefanzweifel/git-auto-commit-action@v5
         with:
-          tool_name: blackfmt
-      - name: Fail if there are formatting suggestions
-        if: steps.black_formatter.outputs.is_formatted == 'true'
-        run: exit 1
+          commit_message: 'style fixes by ruff'
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..a17b4eff
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+default_language_version:
+  python: python3
+
+repos:
+# Run the Ruff linter.
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.1.2
+  hooks:
+    # Run the Ruff formatter.
+    - id: ruff-format
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 34440a27..03eaed22 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -54,6 +54,8 @@ Once inside, you can install the dependencies.
 - Option 2: Run `pip install -e .` to install them, and data-diff, in the global context.
 
+- Run `pre-commit install` to automatically format your code before committing.
+
 At the bare minimum, you need MySQL to run the tests. You can create a local MySQL instance using `docker-compose up mysql`. The URI for it will be `mysql://mysql:Password1@localhost/mysql`. If you're using a different server, make sure to update `TEST_MYSQL_CONN_STRING` in `tests/common.py`. For your convenience, we recommend creating `tests/local_settings.py`, and to override the value there.
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 8caa6817..871c650d 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -199,10 +199,19 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
 class BaseDialect(abc.ABC):
     SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
     SUPPORTS_INDEXES: ClassVar[bool] = False
+    PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
     TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
 
     PLACEHOLDER_TABLE = None  # Used for Oracle
 
+    # Some databases do not support long strings, so concatenation might lead to type overflow
+
+    _prevent_overflow_when_concat: bool = False
+
+    def enable_preventing_type_overflow(self) -> None:
+        logger.info("Preventing type overflow when concatenation is enabled")
+        self._prevent_overflow_when_concat = True
+
     def parse_table_name(self, name: str) -> DbPath:
         "Parse the given table name into a DbPath"
         return parse_table_name(name)
@@ -392,10 +401,19 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
         return f"sum({md5})"
 
     def render_concat(self, c: Compiler, elem: Concat) -> str:
+        if self._prevent_overflow_when_concat:
+            items = [
+                f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}"
+                for expr in elem.exprs
+            ]
         # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
-        items = [
-            f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs
-        ]
+        else:
+            items = [
+                f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')"
+                for expr in elem.exprs
+            ]
+
         assert items
         if len(items) == 1:
             return items[0]
@@ -493,7 +511,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
 
         if elem.limit_expr is not None:
             has_order_by = bool(elem.order_by_exprs)
-            select += " " + self.offset_limit(0, elem.limit_expr, has_order_by=has_order_by)
+            select = self.limit_select(select_query=select, offset=0, limit=elem.limit_expr, has_order_by=has_order_by)
 
         if parent_c.in_select:
             select = f"({select}) {c.new_unique_name()}"
@@ -605,14 +623,17 @@ def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
 
         return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}"
 
-    def offset_limit(
-        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
+    def limit_select(
+        self,
+        select_query: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        has_order_by: Optional[bool] = None,
     ) -> str:
-        "Provide SQL fragment for limit and offset inside a select"
         if offset:
             raise NotImplementedError("No support for OFFSET in query")
 
-        return f"LIMIT {limit}"
+        return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT LIMIT {limit}"
 
     def concat(self, items: List[str]) -> str:
         "Provide SQL for concatenating a bunch of columns into a string"
@@ -766,6 +787,10 @@ def set_timezone_to_utc(self) -> str:
     def md5_as_int(self, s: str) -> str:
         "Provide SQL for computing md5 and returning an int"
 
+    @abstractmethod
+    def md5_as_hex(self, s: str) -> str:
+        """Method to calculate MD5"""
+
     @abstractmethod
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         """Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -882,6 +907,8 @@ class Database(abc.ABC):
     Instanciated using :meth:`~data_diff.connect`
     """
 
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect
+
     SUPPORTS_ALPHANUMS: ClassVar[bool] = True
     SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
     CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []
@@ -889,6 +916,7 @@ class Database(abc.ABC):
     default_schema: Optional[str] = None
     _interactive: bool = False
     is_closed: bool = False
+    _dialect: BaseDialect = None
 
     @property
     def name(self):
@@ -1103,7 +1131,7 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult:
             return result
         except Exception as _e:
             # logger.exception(e)
-            # logger.error(f'Caused by SQL: {sql_code}')
+            # logger.error(f"Caused by SQL: {sql_code}")
             raise
 
     def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
@@ -1117,10 +1145,13 @@ def close(self):
         return super().close()
 
     @property
-    @abstractmethod
     def dialect(self) -> BaseDialect:
         "The dialect of the database. Used internally by Database, and also available publicly."
+        if not self._dialect:
+            self._dialect = self.DIALECT_CLASS()
+        return self._dialect
+
     @property
     @abstractmethod
     def CONNECT_URI_HELP(self) -> str:
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index e672b928..26d8aec3 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -1,5 +1,5 @@
 import re
-from typing import Any, List, Union
+from typing import Any, ClassVar, List, Union, Type
 
 import attrs
 
@@ -134,6 +134,9 @@ def parse_table_name(self, name: str) -> DbPath:
     def md5_as_int(self, s: str) -> str:
         return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -179,9 +182,9 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class BigQuery(Database):
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "bigquery:///"
     CONNECT_URI_PARAMS = ["dataset"]
-    dialect = Dialect()
 
     project: str
     dataset: str
diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py
index 7a8816d8..13082504 100644
--- a/data_diff/databases/clickhouse.py
+++ b/data_diff/databases/clickhouse.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type
+from typing import Any, ClassVar, Dict, Optional, Type
 
 import attrs
 
@@ -105,6 +105,9 @@ def md5_as_int(self, s: str) -> str:
             f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}"
         )
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"hex(MD5({s}))"
+
     def normalize_number(self, value: str, coltype: FractionalType) -> str:
         # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
         # For example:
@@ -164,7 +167,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Clickhouse(ThreadedDatabase):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "clickhouse://:@/"
     CONNECT_URI_PARAMS = ["database?"]
diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py
index 7394f2df..19a1f103 100644
--- a/data_diff/databases/databricks.py
+++ b/data_diff/databases/databricks.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Sequence
+from typing import Any, ClassVar, Dict, Sequence, Type
 import logging
 
 import attrs
 
@@ -82,6 +82,9 @@ def parse_table_name(self, name: str) -> DbPath:
     def md5_as_int(self, s: str) -> str:
         return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         """Databricks timestamp contains no more than 6 digits in precision"""
 
@@ -104,7 +107,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Databricks(ThreadedDatabase):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "databricks://:@/"
     CONNECT_URI_PARAMS = ["catalog", "schema"]
diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py
index a105b71a..6c65b16b 100644
--- a/data_diff/databases/duckdb.py
+++ b/data_diff/databases/duckdb.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Union
+from typing import Any, ClassVar, Dict, Union, Type
 
 import attrs
 
@@ -100,6 +100,9 @@ def current_timestamp(self) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers.
         if coltype.rounds and coltype.precision > 0:
@@ -116,7 +119,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class DuckDB(Database):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     SUPPORTS_UNIQUE_CONSTAINT = False  # Temporary, until we implement it
     CONNECT_URI_HELP = "duckdb://@"
     CONNECT_URI_PARAMS = ["database", "dbpath"]
diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py
index ce5575da..8f5195ee 100644
--- a/data_diff/databases/mssql.py
+++ b/data_diff/databases/mssql.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, ClassVar, Dict, Optional, Type
 
 import attrs
 
@@ -38,7 +38,7 @@ def import_mssql():
 class Dialect(BaseDialect):
     name = "MsSQL"
     ROUNDS_ON_PREC_LOSS = True
-    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
     SUPPORTS_INDEXES = True
     TYPE_CLASSES = {
         # Timestamps
@@ -110,8 +110,12 @@ def is_distinct_from(self, a: str, b: str) -> str:
         # See: https://stackoverflow.com/a/18684859/857383
         return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"
 
-    def offset_limit(
-        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
+    def limit_select(
+        self,
+        select_query: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        has_order_by: Optional[bool] = None,
     ) -> str:
         if offset:
             raise NotImplementedError("No support for OFFSET in query")
@@ -121,7 +125,7 @@ def offset_limit(
             result += "ORDER BY 1"
 
         result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
-        return result
+        return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}"
 
     def constant_values(self, rows) -> str:
         values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
@@ -147,10 +151,13 @@ def normalize_number(self, value: str, coltype: NumericType) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"HashBytes('MD5', {s})"
+
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class MsSQL(ThreadedDatabase):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "mssql://:@//"
     CONNECT_URI_PARAMS = ["database", "schema"]
diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py
index f4993b87..647388f2 100644
--- a/data_diff/databases/mysql.py
+++ b/data_diff/databases/mysql.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, ClassVar, Dict, Type
 
 import attrs
 
@@ -40,7 +40,7 @@ def import_mysql():
 class Dialect(BaseDialect):
     name = "MySQL"
     ROUNDS_ON_PREC_LOSS = True
-    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
     SUPPORTS_INDEXES = True
     TYPE_CLASSES = {
         # Dates
@@ -101,6 +101,9 @@ def set_timezone_to_utc(self) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")
@@ -117,7 +120,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class MySQL(ThreadedDatabase):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     SUPPORTS_ALPHANUMS = False
     SUPPORTS_UNIQUE_CONSTAINT = True
     CONNECT_URI_HELP = "mysql://:@/"
diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py
index a8b8b75b..ab84f0b6 100644
--- a/data_diff/databases/oracle.py
+++ b/data_diff/databases/oracle.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, ClassVar, Dict, List, Optional, Type
 
 import attrs
 
@@ -43,7 +43,7 @@ class Dialect(
     BaseDialect,
 ):
     name = "Oracle"
-    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
     SUPPORTS_INDEXES = True
     TYPE_CLASSES: Dict[str, type] = {
         "NUMBER": Decimal,
@@ -64,13 +64,17 @@ def quote(self, s: str):
     def to_string(self, s: str):
         return f"cast({s} as varchar(1024))"
 
-    def offset_limit(
-        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
+    def limit_select(
+        self,
+        select_query: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        has_order_by: Optional[bool] = None,
     ) -> str:
         if offset:
             raise NotImplementedError("No support for OFFSET in query")
 
-        return f"FETCH NEXT {limit} ROWS ONLY"
+        return f"SELECT * FROM ({select_query}) FETCH NEXT {limit} ROWS ONLY"
 
     def concat(self, items: List[str]) -> str:
         joined_exprs = " || ".join(items)
@@ -133,6 +137,9 @@ def md5_as_int(self, s: str) -> str:
         # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
         return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"standard_hash({s}, 'MD5')"
+
     def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
         # Cast is necessary for correct MD5 (trimming not enough)
         return f"CAST(TRIM({value}) AS VARCHAR(36))"
@@ -157,7 +164,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Oracle(ThreadedDatabase):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "oracle://:@/"
     CONNECT_URI_PARAMS = ["database?"]
diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py
index 075d6aff..4b9e945f 100644
--- a/data_diff/databases/postgresql.py
+++ b/data_diff/databases/postgresql.py
@@ -42,7 +42,7 @@ def import_postgresql():
 class PostgresqlDialect(BaseDialect):
     name = "PostgreSQL"
     ROUNDS_ON_PREC_LOSS = True
-    SUPPORTS_PRIMARY_KEY = True
+    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
     SUPPORTS_INDEXES = True
 
     TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
@@ -98,6 +98,9 @@ def type_repr(self, t) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
@@ -119,7 +122,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class PostgreSQL(ThreadedDatabase):
-    dialect = PostgresqlDialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect
     SUPPORTS_UNIQUE_CONSTAINT = True
     CONNECT_URI_HELP = "postgresql://:@/"
     CONNECT_URI_PARAMS = ["database?"]
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py
index f575719a..ba1c7360 100644
--- a/data_diff/databases/presto.py
+++ b/data_diff/databases/presto.py
@@ -1,6 +1,6 @@
 from functools import partial
 import re
-from typing import Any
+from typing import Any, ClassVar, Type
 
 import attrs
 
@@ -128,6 +128,9 @@ def current_timestamp(self) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"to_hex(md5(to_utf8({s})))"
+
     def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
         # Trim doesn't work on CHAR type
         return f"TRIM(CAST({value} AS VARCHAR))"
@@ -150,7 +153,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Presto(Database):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "presto://@//"
     CONNECT_URI_PARAMS = ["catalog", "schema"]
diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py
index b617f6d0..7a621f57 100644
--- a/data_diff/databases/redshift.py
+++ b/data_diff/databases/redshift.py
@@ -12,6 +12,7 @@
     TimestampTZ,
 )
 from data_diff.databases.postgresql import (
+    BaseDialect,
     PostgreSQL,
     MD5_HEXDIGITS,
     CHECKSUM_HEXDIGITS,
@@ -47,6 +48,9 @@ def type_repr(self, t) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             timestamp = f"{value}::timestamp(6)"
@@ -76,7 +80,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Redshift(PostgreSQL):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "redshift://:@/"
     CONNECT_URI_PARAMS = ["database?"]
@@ -126,9 +130,7 @@ def select_view_columns(self, path: DbPath) -> str:
         return """select * from pg_get_cols('{}.{}')
             cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
-            """.format(
-            schema, table
-        )
+            """.format(schema, table)
 
     def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
         rows = self.query(self.select_view_columns(path), list)
diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py
index 857e7c89..bedacd80 100644
--- a/data_diff/databases/snowflake.py
+++ b/data_diff/databases/snowflake.py
@@ -1,4 +1,4 @@
-from typing import Any, Union, List
+from typing import Any, ClassVar, Union, List, Type
 import logging
 
 import attrs
 
@@ -76,6 +76,9 @@ def type_repr(self, t) -> str:
     def md5_as_int(self, s: str) -> str:
         return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}"
 
+    def md5_as_hex(self, s: str) -> str:
+        return f"md5({s})"
+
     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
         if coltype.rounds:
             timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
@@ -93,7 +96,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
 
 @attrs.define(frozen=False, init=False, kw_only=True)
 class Snowflake(Database):
-    dialect = Dialect()
+    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
     CONNECT_URI_HELP = "snowflake://:@//?warehouse="
"snowflake://:@//?warehouse=" CONNECT_URI_PARAMS = ["database", "schema"] CONNECT_URI_KWPARAMS = ["warehouse"] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index f0c95ee4..b76ba74b 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,11 +1,11 @@ -from typing import Any +from typing import Any, ClassVar, Type import attrs from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto from data_diff.databases.base import import_helper -from data_diff.databases.base import TIMESTAMP_PRECISION_POS +from data_diff.databases.base import TIMESTAMP_PRECISION_POS, BaseDialect @import_helper("trino") @@ -34,7 +34,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Trino(presto.Presto): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "trino://@//" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 51dc00fa..23f63acc 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List, Type import attrs @@ -36,6 +36,7 @@ def import_vertica(): return vertica_python +@attrs.define(frozen=False) class Dialect(BaseDialect): name = "Vertica" ROUNDS_ON_PREC_LOSS = True @@ -109,6 +110,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0)) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"MD5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" @@ -131,7 +135,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Vertica(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "vertica://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 44daba34..66802426 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -208,6 +208,10 @@ def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) + if table1.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT or table2.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT: + table1.database.dialect.enable_preventing_type_overflow() + table2.database.dialect.enable_preventing_type_overflow() + start = time.monotonic() error = None try: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 9c140a5f..3bf14fd7 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -337,6 +337,7 @@ def else_(self, then: Expr) -> Self: @attrs.define(frozen=True, eq=False) class QB_When: "Partial case-when, used for query-building" + casewhen: CaseWhen when: Expr diff --git a/data_diff/version.py b/data_diff/version.py index 5b709086..1819ca0f 100644 --- a/data_diff/version.py +++ b/data_diff/version.py @@ -1 +1 @@ -__version__ = "0.9.12" +__version__ = "0.9.13" diff --git a/poetry.lock b/poetry.lock index 8fa89af1..16e4495c 100644 
--- a/poetry.lock
+++ b/poetry.lock
@@ -195,6 +195,17 @@ files = [
 [package.dependencies]
 pycparser = "*"
 
+[[package]]
+name = "cfgv"
+version = "3.4.0"
+description = "Validate configuration and produce human readable error messages."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
+    {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
+]
+
 [[package]]
 name = "charset-normalizer"
 version = "2.0.12"
@@ -545,6 +556,17 @@ python-dateutil = ">=2.0,<3.0"
 pyyaml = ">=6.0,<7.0"
 typing-extensions = ">=4.0,<5.0"
 
+[[package]]
+name = "distlib"
+version = "0.3.7"
+description = "Distribution utilities"
+optional = false
+python-versions = "*"
+files = [
+    {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"},
+    {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"},
+]
+
 [[package]]
 name = "dsnparse"
 version = "0.1.15"
@@ -651,6 +673,20 @@ files = [
 jsonschema = ">=3.0"
 python-dateutil = ">=2.8,<2.9"
 
+[[package]]
+name = "identify"
+version = "2.5.31"
+description = "File identification library for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "identify-2.5.31-py2.py3-none-any.whl", hash = "sha256:90199cb9e7bd3c5407a9b7e81b4abec4bb9d249991c79439ec8af740afc6293d"},
+    {file = "identify-2.5.31.tar.gz", hash = "sha256:7736b3c7a28233637e3c36550646fc6389bedd74ae84cb788200cc8e2dd60b75"},
+]
+
+[package.extras]
+license = ["ukkonen"]
+
 [[package]]
 name = "idna"
 version = "3.4"
@@ -1104,6 +1140,20 @@ doc = ["nb2plots (>=0.6)", "numpydoc (>=1.1)", "pillow (>=8.2)", "pydata-sphinx-
 extra = ["lxml (>=4.5)", "pydot (>=1.4.1)", "pygraphviz (>=1.7)"]
 test = ["codecov (>=2.1)", "pytest (>=6.2)", "pytest-cov (>=2.12)"]
 
+[[package]]
+name = "nodeenv"
+version = "1.8.0"
+description = "Node.js virtual environment builder"
+optional = false
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
+files = [
+    {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
+    {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
+]
+
+[package.dependencies]
+setuptools = "*"
+
 [[package]]
 name = "oracledb"
 version = "1.3.2"
@@ -1211,6 +1261,39 @@ files = [
     {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
 ]
 
+[[package]]
+name = "platformdirs"
+version = "3.11.0"
+description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"},
+    {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"},
+]
+
+[package.extras]
+docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
+test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
+
+[[package]]
+name = "pre-commit"
+version = "3.5.0"
+description = "A framework for managing and maintaining multi-language pre-commit hooks."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"},
+    {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"},
+]
+
+[package.dependencies]
+cfgv = ">=2.0.0"
+identify = ">=1.0.0"
+nodeenv = ">=0.11.1"
+pyyaml = ">=5.1"
+virtualenv = ">=20.10.0"
+
 [[package]]
 name = "preql"
 version = "0.2.19"
@@ -1724,6 +1807,32 @@ pygments = ">=2.6.0,<3.0.0"
 
 [package.extras]
 jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
 
+[[package]]
+name = "ruff"
+version = "0.1.4"
+description = "An extremely fast Python linter and code formatter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "ruff-0.1.4-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:864958706b669cce31d629902175138ad8a069d99ca53514611521f532d91495"},
+    {file = "ruff-0.1.4-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:9fdd61883bb34317c788af87f4cd75dfee3a73f5ded714b77ba928e418d6e39e"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4eaca8c9cc39aa7f0f0d7b8fe24ecb51232d1bb620fc4441a61161be4a17539"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9a1301dc43cbf633fb603242bccd0aaa34834750a14a4c1817e2e5c8d60de17"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78e8db8ab6f100f02e28b3d713270c857d370b8d61871d5c7d1702ae411df683"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:80fea754eaae06335784b8ea053d6eb8e9aac75359ebddd6fee0858e87c8d510"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6bc02a480d4bfffd163a723698da15d1a9aec2fced4c06f2a753f87f4ce6969c"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862811b403063765b03e716dac0fda8fdbe78b675cd947ed5873506448acea4"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58826efb8b3efbb59bb306f4b19640b7e366967a31c049d49311d9eb3a4c60cb"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:fdfd453fc91d9d86d6aaa33b1bafa69d114cf7421057868f0b79104079d3e66e"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e8791482d508bd0b36c76481ad3117987301b86072158bdb69d796503e1c84a8"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01206e361021426e3c1b7fba06ddcb20dbc5037d64f6841e5f2b21084dc51800"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:645591a613a42cb7e5c2b667cbefd3877b21e0252b59272ba7212c3d35a5819f"},
"ruff-0.1.4-py3-none-win32.whl", hash = "sha256:99908ca2b3b85bffe7e1414275d004917d1e0dfc99d497ccd2ecd19ad115fd0d"}, + {file = "ruff-0.1.4-py3-none-win_amd64.whl", hash = "sha256:1dfd6bf8f6ad0a4ac99333f437e0ec168989adc5d837ecd38ddb2cc4a2e3db8a"}, + {file = "ruff-0.1.4-py3-none-win_arm64.whl", hash = "sha256:d98ae9ebf56444e18a3e3652b3383204748f73e247dea6caaf8b52d37e6b32da"}, + {file = "ruff-0.1.4.tar.gz", hash = "sha256:21520ecca4cc555162068d87c747b8f95e1e95f8ecfcbbe59e8dd00710586315"}, +] + [[package]] name = "runtype" version = "0.2.7" @@ -2002,6 +2111,26 @@ files = [ python-dateutil = ">=1.5" six = ">=1.10.0" +[[package]] +name = "virtualenv" +version = "20.24.6" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, + {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<4" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "wcwidth" version = "0.2.5" @@ -2045,4 +2174,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "925ab3b5dc0ef5a9ddb125e7a3bcf4800bee5d5d3ef67d73a48cfae587fa4505" +content-hash = "b7f8880be9658fa523ff4737d1fdefd09d124ab09fac22fa9ec6a83454940c1d" diff --git a/pyproject.toml b/pyproject.toml index 091d8fb2..e8952a5c 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.9.12" +version = "0.9.13" description = "Command-line tool and Python library to efficiently diff rows across two different databases." authors = ["Datafold "] license = "MIT" @@ -61,6 +61,7 @@ clickhouse-driver = "*" vertica-python = "*" duckdb = "^0.7.0" dbt-core = "^1.0.0" +ruff = "^0.1.4" # google-cloud-bigquery = "*" # databricks-sql-connector = "*" @@ -80,6 +81,9 @@ clickhouse = ["clickhouse-driver"] vertica = ["vertica-python"] duckdb = ["duckdb"] +[tool.poetry.group.dev.dependencies] +pre-commit = "^3.5.0" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..870cba20 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,2 @@ +# Allow lines to be as long as 120. 
+line-length = 120
\ No newline at end of file
diff --git a/tests/common.py b/tests/common.py
index 018f2e32..2fc6be19 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -135,6 +135,7 @@ def str_to_checksum(str: str):
 
 class DiffTestCase(unittest.TestCase):
     "Sets up two tables for diffing"
+
     db_cls = None
     src_schema = None
     dst_schema = None
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index ed8a31b6..0f664c45 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -270,7 +270,9 @@ def test_null_pks(self):
         self.assertRaises(ValueError, list, x)
 
 
-@test_each_database_in_list(d for d in TEST_DATABASES if d.dialect.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT)
+@test_each_database_in_list(
+    d for d in TEST_DATABASES if d.DIALECT_CLASS.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT
+)
 class TestUniqueConstraint(DiffTestCase):
     def setUp(self):
         super().setUp()
diff --git a/tests/test_query.py b/tests/test_query.py
index 69900b4b..2585c02e 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -50,11 +50,16 @@ def current_database(self) -> str:
     def current_schema(self) -> str:
         return "current_schema()"
 
-    def offset_limit(
-        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
+    def limit_select(
+        self,
+        select_query: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        has_order_by: Optional[bool] = None,
     ) -> str:
         x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}"
-        return " ".join(filter(None, x))
+        result = " ".join(filter(None, x))
+        return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}"
 
     def explain_as_text(self, query: str) -> str:
         return f"explain {query}"
@@ -71,6 +76,9 @@ def optimizer_hints(self, s: str):
     def md5_as_int(self, s: str) -> str:
         raise NotImplementedError
 
+    def md5_as_hex(self, s: str) -> str:
+        raise NotImplementedError
+
     def normalize_number(self, value: str, coltype: FractionalType) -> str:
         raise NotImplementedError
 
@@ -192,7 +200,7 @@ def test_funcs(self):
         t = table("a")
 
         q = c.compile(t.order_by(Random()).limit(10))
-        self.assertEqual(q, "SELECT * FROM a ORDER BY random() LIMIT 10")
+        self.assertEqual(q, "SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10")
 
         q = c.compile(t.select(coalesce(this.a, this.b)))
         self.assertEqual(q, "SELECT COALESCE(a, b) FROM a")
@@ -210,7 +218,7 @@ def test_select_distinct(self):
 
         # selects stay apart
         q = c.compile(t.limit(10).select(this.b, distinct=True))
-        self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM a LIMIT 10) tmp1")
+        self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1")
 
         q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
         self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")
@@ -226,7 +234,9 @@ def test_select_with_optimizer_hints(self):
         self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")
 
         q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
-        self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM a LIMIT 10) tmp1")
+        self.assertEqual(
+            q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1"
+        )
 
         q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
         self.assertEqual(
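For context, below is a minimal standalone sketch of the two behavioral changes this diff makes: limit_select now wraps the whole query in a subselect instead of appending a bare LIMIT clause, and render_concat gains an overflow-preventing path that hashes each operand before concatenation. The helpers use plain strings in place of data-diff's Compiler/Code machinery, so apart from the two method names and the generated SQL shapes, everything here is a simplified assumption for illustration, not the library's actual API.

from typing import List, Optional


def limit_select(select_query: str, offset: Optional[int] = None, limit: Optional[int] = None) -> str:
    # Wrap the query in a subselect so dialects without a trailing "LIMIT n"
    # clause (e.g. MsSQL, Oracle) can substitute their own syntax instead.
    if offset:
        raise NotImplementedError("No support for OFFSET in query")
    return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT LIMIT {limit}"


def render_concat(exprs: List[str], prevent_overflow: bool = False) -> str:
    if prevent_overflow:
        # Hash each operand to a fixed-width hex digest first, so the
        # concatenated string cannot exceed the dialect's max string length.
        items = [f"md5({e})" for e in exprs]
    else:
        # Coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL.
        items = [f"coalesce({e}, '')" for e in exprs]
    return items[0] if len(items) == 1 else f"concat({', '.join(items)})"


print(limit_select("SELECT * FROM a ORDER BY random()", limit=10))
# -> SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10
print(render_concat(["col1", "col2"], prevent_overflow=True))
# -> concat(md5(col1), md5(col2))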