diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index c5931979..bf165461 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -931,7 +931,7 @@ def name(self):
     def compile(self, sql_ast):
         return self.dialect.compile(Compiler(self), sql_ast)

-    def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
+    def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
         """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'

         If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
@@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
         if sql_code is SKIP:
             return SKIP

-        logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
+        if log_message:
+            logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
+        else:
+            logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

         if self._interactive and isinstance(sql_ast, Select):
             explained_sql = self.compile(Explain(sql_ast))
@@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         Note: This method exists instead of select_table_schema(), just because not all databases support
               accessing the schema using a SQL query.
         """
-        rows = self.query(self.select_table_schema(path), list)
+        rows = self.query(self.select_table_schema(path), list, log_message=path)
         if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

@@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
         """Query the table for its unique columns for table in 'path', and return {column}"""
         if not self.SUPPORTS_UNIQUE_CONSTAINT:
             raise NotImplementedError("This database doesn't support 'unique' constraints")
-        res = self.query(self.select_table_unique_columns(path), List[str])
+        res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
         return list(res)

     def _process_table_schema(
@@ -1086,7 +1089,9 @@ def _refine_coltypes(
         fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

         samples_by_row = self.query(
-            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
+            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
+            list,
+            log_message=table_path,
         )
         if not samples_by_row:
             raise ValueError(f"Table {table_path} is empty.")
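Illustrative sketch (not part of the patch): the `log_message` parameter threaded through `Database.query()` above reduces to an optional context string that is interpolated into the debug log only when the caller supplies it. A minimal standalone version of the pattern, with hypothetical names:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

def run_query(db_name: str, sql_code: str, log_message: Optional[str] = None) -> None:
    # When the caller supplies context (e.g. a table path), prepend it to the log line.
    if log_message:
        logger.debug("Running SQL (%s): %s \n%s", db_name, log_message, sql_code)
    else:
        logger.debug("Running SQL (%s):\n%s", db_name, sql_code)
```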
diff --git a/data_diff/dbt.py b/data_diff/dbt.py
index bf36c4fc..ef780429 100644
--- a/data_diff/dbt.py
+++ b/data_diff/dbt.py
@@ -8,6 +8,7 @@
 import pydantic
 import rich
 from rich.prompt import Prompt
+from concurrent.futures import ThreadPoolExecutor, as_completed

 from data_diff.errors import (
     DataDiffCustomSchemaNoConfigError,
@@ -80,7 +81,6 @@ def dbt_diff(
     production_schema_flag: Optional[str] = None,
 ) -> None:
     print_version_info()
-    diff_threads = []
     set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt"))
     dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
     models = dbt_parser.get_models(dbt_selection)
@@ -112,7 +112,11 @@ def dbt_diff(
     else:
         dbt_parser.set_connection()

-    with log_status_handler.status if log_status_handler else nullcontext():
+    futures = {}
+
+    with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor(
+        max_workers=dbt_parser.threads
+    ) as executor:
         for model in models:
             if log_status_handler:
                 log_status_handler.set_prefix(f"Diffing {model.alias} \n")
@@ -140,12 +144,12 @@ def dbt_diff(

             if diff_vars.primary_keys:
                 if is_cloud:
-                    diff_thread = run_as_daemon(
+                    future = executor.submit(
                         _cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
                     )
-                    diff_threads.append(diff_thread)
                 else:
-                    _local_diff(diff_vars, json_output)
+                    future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler)
+                futures[future] = model
             else:
                 if json_output:
                     print(
@@ -165,10 +169,12 @@ def dbt_diff(
                         + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
                     )

-    # wait for all threads
-    if diff_threads:
-        for thread in diff_threads:
-            thread.join()
+        for future in as_completed(futures):
+            model = futures[future]
+            try:
+                future.result()  # if error occurred, it will be raised here
+            except Exception as e:
+                logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}")

     _extension_notification()

@@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str,
     return prod_database, prod_schema, prod_alias


-def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
+def _local_diff(
+    diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None
+) -> None:
+    if log_status_handler:
+        log_status_handler.diff_started(diff_vars.dev_path[-1])
     dev_qualified_str = ".".join(diff_vars.dev_path)
     prod_qualified_str = ".".join(diff_vars.prod_path)
     diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str)

-    table1 = connect_to_table(
-        diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
-    )
-    table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads)
+    table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys))
+    table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys))

     try:
         table1_columns = table1.get_schema()
@@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
         diff_output_str += no_differences_template()
         rich.print(diff_output_str)

+    if log_status_handler:
+        log_status_handler.diff_finished(diff_vars.dev_path[-1])
+

 def _initialize_api() -> Optional[DatafoldAPI]:
     datafold_host = os.environ.get("DATAFOLD_HOST")
@@ -406,7 +417,7 @@ def _cloud_diff(
     log_status_handler: Optional[LogStatusHandler] = None,
 ) -> None:
     if log_status_handler:
-        log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
+        log_status_handler.diff_started(diff_vars.dev_path[-1])
     diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
     payload = TCloudApiDataDiff(
         data_source1_id=datasource_id,
@@ -476,7 +487,7 @@ def _cloud_diff(
         rich.print(diff_output_str)

         if log_status_handler:
-            log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
+            log_status_handler.diff_finished(diff_vars.dev_path[-1])
     except BaseException as ex:  # Catch KeyboardInterrupt too
         error = ex
     finally:
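Illustrative sketch (not part of the patch): the dbt.py changes replace ad-hoc daemon threads with a `ThreadPoolExecutor`, submitting every diff (cloud and local alike) and collecting per-model failures through `as_completed`. A rough, self-contained version of that flow; `diff_model` and the model list are stand-ins for `_local_diff`/`_cloud_diff` and the parsed dbt models:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def diff_model(name: str) -> str:
    # Stand-in for _local_diff / _cloud_diff.
    return f"diffed {name}"

models = ["orders", "customers", "payments"]
futures = {}
with ThreadPoolExecutor(max_workers=4) as executor:
    for model in models:
        # Map each future back to its model so failures can be attributed.
        futures[executor.submit(diff_model, model)] = model
    for future in as_completed(futures):
        model = futures[future]
        try:
            print(future.result())  # re-raises any exception from the worker
        except Exception as e:
            print(f"diff failed for {model}: {e}")
```

One design consequence worth noting: because errors surface via `future.result()`, a single failing model is logged and skipped rather than killing the whole run.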
diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py
index 4b6124d5..0d864a57 100644
--- a/data_diff/dbt_parser.py
+++ b/data_diff/dbt_parser.py
@@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
             from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
             if from_meta:
-                logger.debug("Found PKs via META: " + str(from_meta))
+                logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta))
                 return from_meta

             from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
             if from_tags:
-                logger.debug("Found PKs via Tags: " + str(from_tags))
+                logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags))
                 return from_tags

             if node.unique_id in unique_columns:
                 from_uniq = unique_columns.get(node.unique_id)
                 if from_uniq is not None:
-                    logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
+                    logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}")
                     return list(from_uniq)

         except (KeyError, IndexError, TypeError) as e:
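Illustrative sketch (not part of the patch): `get_pk_from_model` above resolves primary keys with a fixed precedence -- column `meta` entries win over column tags, which win over uniqueness tests. Stripped of the dbt node plumbing, that precedence is just an or-chain; `resolve_primary_keys` is a hypothetical helper:

```python
from typing import List, Optional

def resolve_primary_keys(
    meta_pks: List[str], tag_pks: List[str], unique_test_pks: Optional[List[str]]
) -> Optional[List[str]]:
    # Mirrors the lookup order in get_pk_from_model: meta, then tags, then uniqueness tests.
    return meta_pks or tag_pks or (list(unique_test_pks) if unique_test_pks else None)

assert resolve_primary_keys(["id"], ["user_id"], None) == ["id"]
assert resolve_primary_keys([], ["user_id"], None) == ["user_id"]
assert resolve_primary_keys([], [], ["order_id"]) == ["order_id"]
```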
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index 6fadc5d8..8e7fcf30 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre
             yield from self._diff_segments(None, table1, table2, info_tree, None)
         else:
             yield from self._bisect_and_diff_tables(table1, table2, info_tree)
-        logger.info("Diffing complete")
+        logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}")

         if self.materialize_to_table:
             logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))
@@ -193,8 +193,8 @@ def _diff_segments(
                 partial(self._collect_stats, 1, table1, info_tree),
                 partial(self._collect_stats, 2, table2, info_tree),
                 partial(self._test_null_keys, table1, table2),
-                partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
-                partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
+                partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2),
+                partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2),
                 partial(
                     self._materialize_diff,
                     db,
@@ -205,8 +205,8 @@ def _diff_segments(
             else None,
         ):
             assert len(a_cols) == len(b_cols)
-            logger.debug("Querying for different rows")
-            diff = db.query(diff_rows, list)
+            logger.debug(f"Querying for different rows: {table1.table_path}")
+            diff = db.query(diff_rows, list, log_message=table1.table_path)
             info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items()))
             for is_xa, is_xb, *x in diff:
                 if is_xa and is_xb:
@@ -227,7 +227,7 @@ def _diff_segments(
             yield "+", tuple(b_row)

     def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
-        logger.debug("Testing for duplicate keys")
+        logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}")

         # Test duplicate keys
         for ts in [table1, table2]:
@@ -240,16 +240,16 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
             unvalidated = list(set(key_columns) - set(unique))
             if unvalidated:
-                logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}")
+                logger.info(f"Validating that there are no duplicate keys in columns: {unvalidated} for {ts.table_path}")
                 # Validate that there are no duplicate keys
                 self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
                 q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
-                total, total_distinct = ts.database.query(q, tuple)
+                total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path)
                 if total != total_distinct:
                     raise ValueError("Duplicate primary keys")

     def _test_null_keys(self, table1, table2):
-        logger.debug("Testing for null keys")
+        logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}")

         # Test null keys
         for ts in [table1, table2]:
@@ -257,7 +257,7 @@ def _test_null_keys(self, table1, table2):
             key_columns = ts.key_columns

             q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
-            nulls = ts.database.query(q, list)
+            nulls = ts.database.query(q, list, log_message=ts.table_path)
             if nulls:
                 if self.skip_null_keys:
                     logger.warning(
@@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2):
                     raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

     def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
-        logger.debug(f"Collecting stats for table #{i}")
+        logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}")
         db = table_seg.database

         # Metrics
@@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
         )
         col_exprs["count"] = Count()

-        res = db.query(table_seg.make_select().select(**col_exprs), tuple)
+        res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path)

         for col_name, value in safezip(col_exprs, res):
             if value is not None:
@@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
             else:
                 self.stats[stat_name] = value

-        logger.debug("Done collecting stats for table #%s", i)
+        logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path)

     def _create_outer_join(self, table1, table2):
         db = table1.database
@@ -334,23 +334,45 @@ def _create_outer_join(self, table1, table2):
         diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
         return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

-    def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
-        logger.debug("Counting differences per column")
-        is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
+    def _count_diff_per_column(
+        self,
+        db,
+        diff_rows,
+        cols,
+        is_diff_cols,
+        table1: Optional[TableSegment] = None,
+        table2: Optional[TableSegment] = None,
+    ):
+        logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
+        is_diff_cols_counts = db.query(
+            diff_rows.select(sum_(this[c]) for c in is_diff_cols),
+            tuple,
+            log_message=f"{table1.table_path} <> {table2.table_path}",
+        )
         diff_counts = {}
         for name, count in safezip(cols, is_diff_cols_counts):
             diff_counts[name] = diff_counts.get(name, 0) + (count or 0)
         self.stats["diff_counts"] = diff_counts

-    def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
+    def _sample_and_count_exclusive(
+        self,
+        db,
+        diff_rows,
+        a_cols,
+        b_cols,
+        table1: Optional[TableSegment] = None,
+        table2: Optional[TableSegment] = None,
+    ):
         if isinstance(db, (Oracle, MsSQL)):
             exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
         else:
             exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

         if not self.sample_exclusive_rows:
-            logger.debug("Counting exclusive rows")
-            self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int)
+            logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}")
+            self.stats["exclusive_count"] = db.query(
+                exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}"
+            )
             return

         logger.info("Counting and sampling exclusive rows")
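Illustrative sketch (not part of the patch): `_count_diff_per_column` above issues one query that SUMs each `is_diff_*` flag column, then folds the sums into a `{column: count}` dict. The folding step in isolation, with pretend query results:

```python
cols = ["name", "email"]
is_diff_cols_counts = (3, 0)  # pretend result of SELECT SUM(is_diff_name), SUM(is_diff_email)

diff_counts: dict = {}
for name, count in zip(cols, is_diff_cols_counts):
    # `count or 0` guards against a NULL SUM (no differing rows at all).
    diff_counts[name] = diff_counts.get(name, 0) + (count or 0)

assert diff_counts == {"name": 3, "email": 0}
```

Note that the new `table1`/`table2` parameters default to `None` but are dereferenced unconditionally in the log lines; callers in this patch always pass them, so the defaults exist only for signature compatibility.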
index ee4a0f17..b9045cc1 100644
--- a/data_diff/utils.py
+++ b/data_diff/utils.py
@@ -485,31 +485,31 @@ def __init__(self):
         super().__init__()
         self.status = Status("")
         self.prefix = ""
-        self.cloud_diff_status = {}
+        self.diff_status = {}

     def emit(self, record):
         log_entry = self.format(record)
-        if self.cloud_diff_status:
-            self._update_cloud_status(log_entry)
+        if self.diff_status:
+            self._update_diff_status(log_entry)
         else:
             self.status.update(self.prefix + log_entry)

     def set_prefix(self, prefix_string):
         self.prefix = prefix_string

-    def cloud_diff_started(self, model_name):
-        self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
-        self._update_cloud_status()
+    def diff_started(self, model_name):
+        self.diff_status[model_name] = "[yellow]In Progress[/]"
+        self._update_diff_status()

-    def cloud_diff_finished(self, model_name):
-        self.cloud_diff_status[model_name] = "[green]Finished [/]"
-        self._update_cloud_status()
+    def diff_finished(self, model_name):
+        self.diff_status[model_name] = "[green]Finished [/]"
+        self._update_diff_status()

-    def _update_cloud_status(self, log=None):
-        cloud_status_string = "\n"
-        for model_name, status in self.cloud_diff_status.items():
-            cloud_status_string += f"{status} {model_name}\n"
-        self.status.update(f"{cloud_status_string}{log or ''}")
+    def _update_diff_status(self, log=None):
+        status_string = "\n"
+        for model_name, status in self.diff_status.items():
+            status_string += f"{status} {model_name}\n"
+        self.status.update(f"{status_string}{log or ''}")


 class UnknownMeta(type):
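Illustrative sketch (not part of the patch): the renamed handler keeps one status line per model and redraws the whole block through `rich.status.Status` on every transition, which is what lets local and cloud diffs share the same display. `Status` and `update()` are real rich APIs; the driver loop and model names are invented for illustration:

```python
import time
from rich.status import Status

diff_status = {}
with Status("") as status:
    def render() -> None:
        # Redraw one line per model, newest state included.
        lines = "\n".join(f"{state} {model}" for model, state in diff_status.items())
        status.update("\n" + lines)

    for model in ["orders", "customers"]:
        diff_status[model] = "[yellow]In Progress[/]"
        render()
        time.sleep(0.5)  # stand-in for the actual diff work
        diff_status[model] = "[green]Finished   [/]"
        render()
```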
diff --git a/tests/test_dbt.py b/tests/test_dbt.py
index c281b6fb..31af99eb 100644
--- a/tests/test_dbt.py
+++ b/tests/test_dbt.py
@@ -93,8 +93,8 @@ def test_local_diff(self, mock_diff_tables):
         )
         self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
         self.assertEqual(mock_connect.call_count, 2)
-        mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), threads)
-        mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), threads)
+        mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys))
+        mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys))
         mock_diff.get_stats_string.assert_called_once()

     @patch("data_diff.dbt.diff_tables")
@@ -180,8 +180,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
         )
         self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
         self.assertEqual(mock_connect.call_count, 2)
-        mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), None)
-        mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), None)
+        mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys))
+        mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys))
         mock_diff.get_stats_string.assert_not_called()

     @patch("data_diff.dbt.rich.print")
@@ -248,6 +248,7 @@ def test_diff_is_cloud(
         where = "a_string"
         config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1)
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_model = Mock()
         mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
         mock_initialize_api.return_value = mock_api
@@ -386,6 +387,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
         threads = None
         where = "a_string"
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_dbt_parser.return_value = mock_dbt_parser_inst
         mock_model = Mock()
         mock_dbt_parser_inst.get_models.return_value = [mock_model]
@@ -407,7 +409,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
         mock_dbt_parser_inst.get_models.assert_called_once()
         mock_dbt_parser_inst.set_connection.assert_called_once()
         mock_cloud_diff.assert_not_called()
-        mock_local_diff.assert_called_once_with(diff_vars, False)
+        mock_local_diff.assert_called_once_with(diff_vars, False, None)
         mock_print.assert_not_called()

     @patch("data_diff.dbt._get_diff_vars")
@@ -423,6 +425,7 @@ def test_diff_state_model_dne(
         threads = None
         where = "a_string"
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_dbt_parser.return_value = mock_dbt_parser_inst
         mock_model = Mock()
         mock_dbt_parser_inst.get_models.return_value = [mock_model]
@@ -460,6 +463,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
         threads = None
         where = "a_string"
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_dbt_parser.return_value = mock_dbt_parser_inst
         mock_model = Mock()
         mock_dbt_parser_inst.get_models.return_value = [mock_model]
@@ -481,7 +485,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
         mock_dbt_parser_inst.get_models.assert_called_once()
         mock_dbt_parser_inst.set_connection.assert_called_once()
         mock_cloud_diff.assert_not_called()
-        mock_local_diff.assert_called_once_with(diff_vars, False)
+        mock_local_diff.assert_called_once_with(diff_vars, False, None)
         mock_print.assert_not_called()

     @patch("data_diff.dbt._get_diff_vars")
@@ -497,6 +501,7 @@ def test_diff_only_prod_schema(
         threads = None
         where = "a_string"
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_dbt_parser.return_value = mock_dbt_parser_inst
         mock_model = Mock()
         mock_dbt_parser_inst.get_models.return_value = [mock_model]
@@ -518,7 +523,7 @@ def test_diff_only_prod_schema(
         mock_dbt_parser_inst.get_models.assert_called_once()
         mock_dbt_parser_inst.set_connection.assert_called_once()
         mock_cloud_diff.assert_not_called()
-        mock_local_diff.assert_called_once_with(diff_vars, False)
+        mock_local_diff.assert_called_once_with(diff_vars, False, None)
         mock_print.assert_not_called()

     @patch("data_diff.dbt._initialize_api")
@@ -543,6 +548,7 @@ def test_diff_is_cloud_no_pks(
         mock_model = Mock()
         connection = {}
         threads = None
+        mock_dbt_parser_inst.threads = threads
         where = "a_string"
         config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1)
         mock_api = Mock()
@@ -584,6 +590,7 @@ def test_diff_not_is_cloud_no_pks(
         threads = None
         where = "a_string"
         mock_dbt_parser_inst = Mock()
+        mock_dbt_parser_inst.threads = threads
         mock_dbt_parser.return_value = mock_dbt_parser_inst
         mock_model = Mock()
         mock_dbt_parser_inst.get_models.return_value = [mock_model]
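Illustrative sketch (not part of the patch): the repeated `mock_dbt_parser_inst.threads = threads` additions are presumably needed because `dbt_diff` now passes `dbt_parser.threads` straight into `ThreadPoolExecutor(max_workers=...)`. A bare `Mock` attribute would fail the executor's `max_workers` validation, while an explicit `None` is accepted and lets the pool pick its default size:

```python
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock

mock_dbt_parser_inst = Mock()
mock_dbt_parser_inst.threads = None  # as stubbed in the updated tests

# max_workers=None is valid; an auto-created Mock attribute would raise a
# TypeError inside ThreadPoolExecutor's max_workers check.
with ThreadPoolExecutor(max_workers=mock_dbt_parser_inst.threads) as executor:
    assert executor.submit(lambda: 42).result() == 42
```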