From 35da209a3c899f0ed962c8b8d59b1d400cd9a359 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 22:06:07 +0100 Subject: [PATCH 01/44] Automatic pre-commit fixes --- datafaker/base.py | 54 +- datafaker/create.py | 30 +- datafaker/dump.py | 21 +- datafaker/generators.py | 601 ++++-- datafaker/interactive.py | 559 +++-- datafaker/main.py | 110 +- datafaker/make.py | 110 +- datafaker/providers.py | 8 +- datafaker/remove.py | 21 +- datafaker/serialize_metadata.py | 73 +- datafaker/utils.py | 109 +- docs/source/_static/config_schema.html | 2760 +++++++++++++++++++++++- docs/source/custom_generators.rst | 2 +- docs/source/introduction.rst | 2 +- docs/source/quickstart.rst | 6 +- tests/test_create.py | 19 +- tests/test_dump.py | 23 +- tests/test_functional.py | 28 +- tests/test_interactive.py | 627 ++++-- tests/test_main.py | 71 +- tests/test_make.py | 52 +- tests/test_providers.py | 3 +- tests/test_remove.py | 70 +- tests/test_unique_generator.py | 1 + tests/test_utils.py | 323 ++- tests/utils.py | 43 +- 26 files changed, 4627 insertions(+), 1099 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 17f471b5..56315a40 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -1,35 +1,34 @@ """Base table generator classes.""" -from abc import ABC, abstractmethod -from collections.abc import Callable -from dataclasses import dataclass import functools +import gzip import math -import numpy as np import os -from pathlib import Path import random +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path from typing import Any +import numpy as np import yaml -import gzip from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.schema import Table from datafaker.utils import ( + MAKE_VOCAB_PROGRESS_REPORT_EVERY, logger, stream_yaml, - MAKE_VOCAB_PROGRESS_REPORT_EVERY, table_row_count, ) + @functools.cache def zipf_weights(size): - 
total = sum(map(lambda n: 1/n, range(1, size + 1))) - return [ - 1 / (n * total) - for n in range(1, size + 1) - ] + total = sum(map(lambda n: 1 / n, range(1, size + 1))) + return [1 / (n * total) for n in range(1, size + 1)] + def merge_with_constants(xs: list, constants_at: dict[int, any]): """ @@ -118,10 +117,7 @@ def multivariate_normal_np(self, cov): rank = int(cov["rank"]) if rank == 0: return np.empty(shape=(0,)) - mean = [ - float(cov[f"m{i}"]) - for i in range(rank) - ] + mean = [float(cov[f"m{i}"]) for i in range(rank)] covs = [ [ float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) @@ -138,7 +134,9 @@ def _select_group(self, alts: list[dict[str, any]]): total = 0 for alt in alts: if alt["count"] < 0: - logger.warning("Alternative count is %d, but should not be negative", alt["count"]) + logger.warning( + "Alternative count is %d, but should not be negative", alt["count"] + ) else: total += alt["count"] if total == 0: @@ -218,7 +216,9 @@ def _check_generator_name(self, name: str) -> None: if name not in self.PERMITTED_SUBGENS: raise Exception("%s is not a permitted generator", name) - def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[int] | None): + def alternatives( + self, alternative_configs: list[dict[str, any]], counts: list[int] | None + ): """ A generator that picks between other generators. @@ -245,7 +245,9 @@ def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[i self._check_generator_name(name) return getattr(self, name)(**alt["params"]) - def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[str, any]): + def with_constants_at( + self, constants_at: list[int], subgen: str, params: dict[str, any] + ): if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. 
Valid names are %s.", @@ -257,7 +259,7 @@ def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[s return list(merge_with_constants(subout, constants_at)) def truncated_string(self, subgen_fn, params, length): - """ Calls ``subgen_fn(**params)`` and truncates the results to ``length``. """ + """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: return None @@ -288,7 +290,9 @@ class FileUploader: table: Table - def _load_existing_file(self, connection: Connection, file_size: int, opener: Callable[[], Any]) -> None: + def _load_existing_file( + self, connection: Connection, file_size: int, opener: Callable[[], Any] + ) -> None: count = 0 with opener() as fh: rows = stream_yaml(fh) @@ -305,7 +309,7 @@ def _load_existing_file(self, connection: Connection, file_size: int, opener: Ca 100 * fh.tell() / file_size, ) - def load(self, connection: Connection, base_path: Path=Path(".")) -> None: + def load(self, connection: Connection, base_path: Path = Path(".")) -> None: """Load the data from file.""" yaml_file = base_path / Path(self.table.fullname + ".yaml") if yaml_file.exists(): @@ -318,7 +322,10 @@ def load(self, connection: Connection, base_path: Path=Path(".")) -> None: logger.warning("File %s not found. 
Skipping...", yaml_file) return if 0 < table_row_count(self.table, connection): - logger.warning("Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", self.table.name) + logger.warning( + "Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", + self.table.name, + ) return try: file_size = os.path.getsize(yaml_file) @@ -331,6 +338,7 @@ def load(self, connection: Connection, base_path: Path=Path(".")) -> None: "Error inserting rows into table %s: %s", self.table.fullname, e ) + class ColumnPresence: def sampled(self, patterns): total = 0 diff --git a/datafaker/create.py b/datafaker/create.py index d84eadb6..f5228762 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,17 +1,14 @@ """Functions and classes to create and populate the target database.""" -from collections import Counter import pathlib import random +from collections import Counter from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from sqlalchemy.schema import ( - CreateSchema, - MetaData, - Table, -) +from sqlalchemy.schema import CreateSchema, MetaData, Table + from datafaker.base import FileUploader, TableGenerator from datafaker.settings import get_settings from datafaker.utils import ( @@ -56,11 +53,11 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path | None=pathlib.Path(".") + base_path: pathlib.Path | None = pathlib.Path("."), ) -> int: """ Load vocabulary tables from files. 
- + arguments: metadata: The schema of the database meta_dict: The simple description of the schema from --orm-file @@ -85,7 +82,7 @@ def create_db_vocab( uploader = FileUploader(table=vocab_table) with Session(dst_engine) as session: session.begin() - uploader.load(session.connection(), base_path = base_path) + uploader.load(session.connection(), base_path=base_path) session.commit() tables_loaded.append(vocab_table_name) except IntegrityError: @@ -128,9 +125,7 @@ def create_db_data_into( db_dsn: str, schema_name: str | None, ) -> RowCounts: - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) row_counts: Counter[str] = Counter() with dst_engine.connect() as dst_conn: @@ -145,7 +140,8 @@ def create_db_data_into( class StoryIterator: - def __init__(self, + def __init__( + self, stories: Iterable[tuple[str, Story]], table_dict: Mapping[str, Table], table_generator_dict: Mapping[str, TableGenerator], @@ -219,7 +215,9 @@ def next(self) -> None: self._table_name, self._provided_values = next(self._story) return else: - self._table_name, self._provided_values = self._story.send(self._final_values) + self._table_name, self._provided_values = self._story.send( + self._final_values + ) return except StopIteration: try: @@ -274,7 +272,9 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values(table_generator(dst_conn, random.random)) + stmt = insert(table).values( + table_generator(dst_conn, random.random) + ) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() diff --git a/datafaker/dump.py b/datafaker/dump.py index c4d2b24f..36ca046c 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,11 @@ import csv import io + import sqlalchemy from sqlalchemy.schema import MetaData from datafaker.settings import get_settings -from 
datafaker.utils import ( - create_db_engine, - get_sync_engine, - logger, -) +from datafaker.utils import create_db_engine, get_sync_engine, logger def _make_csv_writer(file): @@ -16,14 +13,14 @@ def _make_csv_writer(file): def dump_db_tables( - metadata: MetaData, - dsn: str, - schema: str | None, - table_name: str, - file: io.TextIOBase + metadata: MetaData, + dsn: str, + schema: str | None, + table_name: str, + file: io.TextIOBase, ) -> None: - """ Output the table as CSV. """ - if table_name not in metadata.tables: + """Output the table as CSV.""" + if table_name not in metadata.tables: logger.error("%s is not a table described in the ORM file", table_name) return table = metadata.tables[table_name] diff --git a/datafaker/generators.py b/datafaker/generators.py index dd6e8187..1ea0a8f5 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -2,20 +2,21 @@ Generator factories for making generators for single columns. """ +import decimal +import math +import re from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass -import decimal from functools import lru_cache from itertools import chain, combinations -import math +from typing import Callable, Iterable, TypeVar + import mimesis import mimesis.locales -import re import sqlalchemy -from sqlalchemy import Column, Engine, text, Connection, RowMapping, Sequence +from sqlalchemy import Column, Connection, Engine, RowMapping, Sequence, text from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from typing import Callable, Iterable, TypeVar from datafaker.base import DistributionGenerator from datafaker.utils import logger @@ -27,6 +28,7 @@ dist_gen = DistributionGenerator() generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) + class Generator(ABC): """ Random data generator. 
@@ -40,9 +42,10 @@ class Generator(ABC): It also knows these summary statistics for the column it was instantiated on, and therefore knows how to generate fake data for that column. """ + @abstractmethod def function_name(self) -> str: - """ The name of the generator function to put into df.py. """ + """The name of the generator function to put into df.py.""" def name(self) -> str: """ @@ -60,7 +63,7 @@ def nominal_kwargs(self) -> dict[str, str]: The values will tend to be references to something in the src-stats.yaml file. For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will - provide the value stored in src-stats.yaml as + provide the value stored in src-stats.yaml as SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument to the generator function. """ @@ -130,6 +133,7 @@ class PredefinedGenerator(Generator): """ Generator built from an existing config.yaml. """ + SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') @@ -152,13 +156,22 @@ def _get_src_stats_mentioned(self, val) -> set[str]: return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) return set() - def __init__(self, table_name: str, generator_object: Mapping[str, any], config: Mapping[str, any]): + def __init__( + self, + table_name: str, + generator_object: Mapping[str, any], + config: Mapping[str, any], + ): """ Initialise a generator from a config.yaml. :param config: The entire configuration. 
:param generator_object: The part of the configuration at tables.*.row_generators """ - logger.debug("Creating a PredefinedGenerator %s from table %s", generator_object["name"], table_name) + logger.debug( + "Creating a PredefinedGenerator %s from table %s", + generator_object["name"], + table_name, + ) self._table_name = table_name self._name: str = generator_object["name"] self._kwn: dict[str, str] = generator_object.get("kwargs", {}) @@ -170,7 +183,9 @@ def __init__(self, table_name: str, generator_object: Mapping[str, any], config: for sstat in config.get("src-stats", []): name: str = sstat["name"] dpq = sstat.get("dp-query", None) - query = sstat.get("query", dpq) #... should we really be combining query and dp-query? + query = sstat.get( + "query", dpq + ) # ... should we really be combining query and dp-query? comments = sstat.get("comments", []) if name in self._src_stats_mentioned: logger.debug("Found a src-stats entry for %s", name) @@ -181,7 +196,7 @@ def __init__(self, table_name: str, generator_object: Mapping[str, any], config: # name is auto__{table_name}, so it's a select_aggregate, so we split up its clauses sacs = [ self.AS_CLAUSE_RE.match(clause) - for clause in sam.group(1).split(',') + for clause in sam.group(1).split(",") ] # Work out what select_aggregate_clauses this represents for sac in sacs: @@ -213,13 +228,13 @@ def custom_queries(self) -> dict[str, dict[str, str]]: def actual_kwargs(self) -> dict[str, any]: # Run the queries from nominal_kwargs - #... + # ... logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} def generate_data(self, count) -> list[any]: # Call the function if we can. This could be tricky... - #... + # ... logger.error("PredefinedGenerator.generate_data not implemented yet") return [] @@ -227,7 +242,8 @@ def generate_data(self, count) -> list[any]: class GeneratorFactory(ABC): """ A factory for making generators appropriate for a database column. 
- """ + """ + @abstractmethod def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: """ @@ -240,13 +256,27 @@ class Buckets: Finds the real distribution of continuous data so that we can measure the fit of generators against it. """ - def __init__(self, engine: Engine, table_name: str, column_name: str, mean:float, stddev: float, count: int): + + def __init__( + self, + engine: Engine, + table_name: str, + column_name: str, + mean: float, + stddev: float, + count: int, + ): with engine.connect() as connection: - raw_buckets = connection.execute(text( - "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( - column=column_name, table=table_name, x=mean - 2 * stddev, w = stddev / 2 + raw_buckets = connection.execute( + text( + "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( + column=column_name, + table=table_name, + x=mean - 2 * stddev, + w=stddev / 2, + ) ) - )) + ) self.buckets = [0] * 10 for rb in raw_buckets: if rb.b is not None: @@ -259,7 +289,7 @@ def __init__(self, engine: Engine, table_name: str, column_name: str, mean:float def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): """ Construct a Buckets object. 
- + Calculates the mean and standard deviation of the values in the column specified and makes ten buckets, centered on the mean and each half a standard deviation wide (except for the end two that extend to @@ -268,10 +298,12 @@ def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): """ with engine.connect() as connection: result = connection.execute( - text("SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( - table=table_name, - column=column_name, - )) + text( + "SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( + table=table_name, + column=column_name, + ) + ) ).first() if result is None or result.stddev is None or result.count < 2: return None @@ -303,13 +335,14 @@ def fit_from_values(self, values: list[float]) -> float: x = self.mean - 2 * self.stddev w = self.stddev / 2 for v in values: - b = min(9, max(0, int((v - x)/w))) + b = min(9, max(0, int((v - x) / w))) buckets[b] += 1 return self.fit_from_counts(buckets) class MultiGeneratorFactory(GeneratorFactory): - """ A composite factory. """ + """A composite factory.""" + def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories @@ -336,27 +369,30 @@ def __init__( f = generic for part in function_name.split("."): if not hasattr(f, part): - raise Exception(f"Mimesis does not have a function {function_name}: {part} not found") + raise Exception( + f"Mimesis does not have a function {function_name}: {part} not found" + ) f = getattr(f, part) if not callable(f): - raise Exception(f"Mimesis object {function_name} is not a callable, so cannot be used as a generator") + raise Exception( + f"Mimesis object {function_name} is not a callable, so cannot be used as a generator" + ) self._name = "generic." 
+ function_name self._generator_function = f + def function_name(self): return self._name + def generate_data(self, count): - return [ - self._generator_function() - for _ in range(count) - ] + return [self._generator_function() for _ in range(count)] class MimesisGenerator(MimesisGeneratorBase): def __init__( self, function_name: str, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, + value_fn: Callable[[any], float] | None = None, + buckets: Buckets | None = None, ): """ Generator from Mimesis. @@ -373,17 +409,18 @@ def __init__( return samples = self.generate_data(400) if value_fn: - samples = [ - value_fn(s) - for s in samples - ] + samples = [value_fn(s) for s in samples] self._fit = buckets.fit_from_values(samples) + def function_name(self): return self._name + def nominal_kwargs(self): return {} + def actual_kwargs(self): return {} + def fit(self, default=None): return default if self._fit is None else self._fit @@ -393,36 +430,46 @@ def __init__( self, function_name: str, length: int, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, + value_fn: Callable[[any], float] | None = None, + buckets: Buckets | None = None, ): self._length = length super().__init__(function_name, value_fn, buckets) + def function_name(self): return "dist_gen.truncated_string" + def name(self): return f"{self._name} [truncated to {self._length}]" + def nominal_kwargs(self): return { "subgen_fn": self._name, "params": {}, "length": self._length, } + def actual_kwargs(self): return { "subgen_fn": self._name, "params": {}, "length": self._length, } + def generate_data(self, count): - return [ - self._generator_function()[:self._length] - for _ in range(count) - ] + return [self._generator_function()[: self._length] for _ in range(count)] class MimesisDateTimeGenerator(MimesisGeneratorBase): - def __init__(self, column: Column, function_name: str, min_year: str, max_year: str, start: int, end: int): + def __init__( + self, + 
column: Column, + function_name: str, + min_year: str, + max_year: str, + start: int, + end: int, + ): """ :param column: The column to generate into :param function_name: The name of the mimesis function @@ -445,28 +492,35 @@ def make_singleton(_cls, column: Column, engine: Engine, function_name: str): min_year = f"MIN({extract_year})" with engine.connect() as connection: result = connection.execute( - text(f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}") + text( + f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}" + ) ).first() if result is None or result.start is None or result.end is None: return [] - return [MimesisDateTimeGenerator( - column, - function_name, - min_year, - max_year, - int(result.start), - int(result.end), - )] + return [ + MimesisDateTimeGenerator( + column, + function_name, + min_year, + max_year, + int(result.start), + int(result.end), + ) + ] + def nominal_kwargs(self): return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } + def actual_kwargs(self): return { "start": self._start, "end": self._end, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { f"{self._column.name}__start": { @@ -478,6 +532,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: "comment": f"Latest year found for column {self._column.name} in table {self._column.table.name}", }, } + def generate_data(self, count): return [ self._generator_function(start=self._start, end=self._end) @@ -496,6 +551,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return strings. 
""" + GENERATOR_NAMES = [ "address.calling_code", "address.city", @@ -529,6 +585,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.text", "text.word", ] + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -551,35 +608,48 @@ def get_generators(self, columns: list[Column], engine: Engine): fitness_fn = None length = column_type.length if length: - return list(map( - lambda gen: MimesisGeneratorTruncated(gen, length, fitness_fn, buckets), + return list( + map( + lambda gen: MimesisGeneratorTruncated( + gen, length, fitness_fn, buckets + ), + self.GENERATOR_NAMES, + ) + ) + return list( + map( + lambda gen: MimesisGenerator(gen, fitness_fn, buckets), self.GENERATOR_NAMES, - )) - return list(map( - lambda gen: MimesisGenerator(gen, fitness_fn, buckets), - self.GENERATOR_NAMES, - )) + ) + ) class MimesisFloatGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return floating point numbers. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] column = columns[0] if not isinstance(get_column_type(column), Numeric): return [] - return list(map(MimesisGenerator, [ - "person.height", - ])) + return list( + map( + MimesisGenerator, + [ + "person.height", + ], + ) + ) class MimesisDateGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return dates. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -594,6 +664,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return datetimes. 
""" + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -601,13 +672,16 @@ def get_generators(self, columns: list[Column], engine: Engine): ct = get_column_type(column) if not isinstance(ct, DateTime): return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.datetime") + return MimesisDateTimeGenerator.make_singleton( + column, engine, "datetime.datetime" + ) class MimesisTimeGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return times. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -622,6 +696,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return integers. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -633,7 +708,7 @@ def get_generators(self, columns: list[Column], engine: Engine): def fit_from_buckets(xs: list[float], ys: list[float]): - sum_diff_squared = sum(map(lambda t, a: (t - a)*(t - a), xs, ys)) + sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) @@ -644,11 +719,13 @@ def __init__(self, table_name: str, column_name: str, buckets: Buckets): self.table_name = table_name self.column_name = column_name self.buckets = buckets + def nominal_kwargs(self): return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } + def actual_kwargs(self): if self.buckets is None: return {} @@ -656,6 +733,7 @@ def actual_kwargs(self): "mean": self.buckets.mean, "sd": self.buckets.stddev, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: clauses = super().select_aggregate_clauses() return { @@ -669,6 +747,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: "comment": f"Standard 
deviation of {self.column_name} from table {self.table_name}", }, } + def fit(self, default=None): if self.buckets is None: return default @@ -676,9 +755,22 @@ def fit(self, default=None): class GaussianGenerator(ContinuousDistributionGenerator): - expected_buckets = [0.0227, 0.0441, 0.0918, 0.1499, 0.1915, 0.1915, 0.1499, 0.0918, 0.0441, 0.0227] + expected_buckets = [ + 0.0227, + 0.0441, + 0.0918, + 0.1499, + 0.1915, + 0.1915, + 0.1499, + 0.0918, + 0.0441, + 0.0227, + ] + def function_name(self): return "dist_gen.normal" + def generate_data(self, count): return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) @@ -687,9 +779,22 @@ def generate_data(self, count): class UniformGenerator(ContinuousDistributionGenerator): - expected_buckets = [0, 0.06698, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.06698, 0] + expected_buckets = [ + 0, + 0.06698, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.06698, + 0, + ] + def function_name(self): return "dist_gen.uniform_ms" + def generate_data(self, count): return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) @@ -701,6 +806,7 @@ class ContinuousDistributionGeneratorFactory(GeneratorFactory): """ All generators that want an average and standard deviation. 
""" + def _get_generators_from_buckets( self, _engine: Engine, @@ -725,36 +831,59 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato buckets = Buckets.make_buckets(engine, table_name, column_name) if buckets is None: return [] - return self._get_generators_from_buckets(engine, table_name, column_name, buckets) + return self._get_generators_from_buckets( + engine, table_name, column_name, buckets + ) class LogNormalGenerator(Generator): - #TODO: figure out the real buckets here (this was from a random sample in R) - expected_buckets = [0, 0, 0, 0.28627, 0.40607, 0.14937, 0.06735, 0.03492, 0.01918, 0.03684] - def __init__(self, table_name: str, column_name: str, buckets: Buckets, logmean: float, logstddev: float): + # TODO: figure out the real buckets here (this was from a random sample in R) + expected_buckets = [ + 0, + 0, + 0, + 0.28627, + 0.40607, + 0.14937, + 0.06735, + 0.03492, + 0.01918, + 0.03684, + ] + + def __init__( + self, + table_name: str, + column_name: str, + buckets: Buckets, + logmean: float, + logstddev: float, + ): super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets self.logmean = logmean self.logstddev = logstddev + def function_name(self): return "dist_gen.lognormal" + def generate_data(self, count): - return [ - dist_gen.lognormal(self.logmean, self.logstddev) - for _ in range(count) - ] + return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] + def nominal_kwargs(self): return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } + def actual_kwargs(self): return { "logmean": self.logmean, "logsd": self.logstddev, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: clauses = super().select_aggregate_clauses() return { @@ -768,6 +897,7 @@ def select_aggregate_clauses(self) -> 
dict[str, dict[str, str]]: "comment": f"Standard deviation of logs of {self.column_name} from table {self.table_name}", }, } + def fit(self, default=None): if self.buckets is None: return default @@ -778,6 +908,7 @@ class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorF """ All generators that want an average and standard deviation of log data. """ + def _get_generators_from_buckets( self, engine: Engine, @@ -787,10 +918,12 @@ def _get_generators_from_buckets( ) -> list[Generator]: with engine.connect() as connection: result = connection.execute( - text("SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( - table=table_name, - column=column_name, - )) + text( + "SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( + table=table_name, + column=column_name, + ) + ) ).first() if result is None or result.logstddev is None: return [] @@ -806,7 +939,7 @@ def _get_generators_from_buckets( def zipf_distribution(total, bins): - basic_dist = list(map(lambda n: 1/n, range(1, bins + 1))) + basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) bd_remaining = sum(basic_dist) for b in basic_dist: # yield b/bd_remaining of the `total` remaining @@ -821,14 +954,15 @@ def zipf_distribution(total, bins): class ChoiceGenerator(Generator): STORE_COUNTS = False + def __init__( self, table_name, column_name, values, counts, - sample_count = None, - suppress_count = 0, + sample_count=None, + suppress_count=0, ): super().__init__() self.table_name = table_name @@ -868,19 +1002,23 @@ def get_estimated_counts(counts): """ The counts that we would expect if this distribution was the correct one. 
""" + def nominal_kwargs(self): return { "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } + def name(self): n = super().name() if self._annotation is None: return n return f"{n} [{self._annotation}]" + def actual_kwargs(self): return { "a": self.values, } + def custom_queries(self) -> dict[str, dict[str, str]]: qs = super().custom_queries() return { @@ -888,20 +1026,23 @@ def custom_queries(self) -> dict[str, dict[str, str]]: f"auto__{self.table_name}__{self.column_name}": { "query": self._query, "comment": self._comment, - } + }, } + def fit(self, default=None): return default if self._fit is None else self._fit + class ZipfChoiceGenerator(ChoiceGenerator): def get_estimated_counts(self, counts): return list(zipf_distribution(sum(counts), len(counts))) + def function_name(self): return "dist_gen.zipf_choice" + def generate_data(self, count): return [ - dist_gen.zipf_choice(self.values, len(self.values)) - for _ in range(count) + dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] @@ -917,34 +1058,35 @@ def uniform_distribution(total, bins): class UniformChoiceGenerator(ChoiceGenerator): def get_estimated_counts(self, counts): return list(uniform_distribution(sum(counts), len(counts))) + def function_name(self): return "dist_gen.choice" + def generate_data(self, count): - return [ - dist_gen.choice(self.values) - for _ in range(count) - ] + return [dist_gen.choice(self.values) for _ in range(count)] class WeightedChoiceGenerator(ChoiceGenerator): STORE_COUNTS = True + def get_estimated_counts(self, counts): return counts + def function_name(self): return "dist_gen.weighted_choice" + def generate_data(self, count): - return [ - dist_gen.weighted_choice(self.values) - for _ in range(count) - ] + return [dist_gen.weighted_choice(self.values) for _ in range(count)] class ChoiceGeneratorFactory(GeneratorFactory): """ All generators that want an average and standard deviation. 
""" + SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -954,16 +1096,20 @@ def get_generators(self, columns: list[Column], engine: Engine): generators = [] with engine.connect() as connection: results = connection.execute( - text("SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( - table=table_name, - column=column_name, - limit=MAXIMUM_CHOICES+1, - )) + text( + "SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( + table=table_name, + column=column_name, + limit=MAXIMUM_CHOICES + 1, + ) + ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: values = [] # The values found counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" + cvs: list[ + dict[str, any] + ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f if c != 0: @@ -980,19 +1126,27 @@ def get_generators(self, columns: list[Column], engine: Engine): WeightedChoiceGenerator(table_name, column_name, cvs, counts), ] results = connection.execute( - text("SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( - table=table_name, - column=column_name, - sample_count=self.SAMPLE_COUNT, - )) + text( + "SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( + table=table_name, + column=column_name, + sample_count=self.SAMPLE_COUNT, + ) + ) ) if results is not None: values = [] # All values found counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times + cvs: list[ + dict[str, any] + ] 
= [] # list of dicts with keys "v" and "count" + values_not_suppressed = ( + [] + ) # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" + cvs_not_suppressed: list[ + dict[str, any] + ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f if c != 0: @@ -1011,9 +1165,27 @@ def get_generators(self, columns: list[Column], engine: Engine): cvs_not_suppressed.append({"value": v, "count": c}) if counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - UniformChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - WeightedChoiceGenerator(table_name, column_name, cvs, counts, sample_count=self.SAMPLE_COUNT), + ZipfChoiceGenerator( + table_name, + column_name, + values, + counts, + sample_count=self.SAMPLE_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + values, + counts, + sample_count=self.SAMPLE_COUNT, + ), + WeightedChoiceGenerator( + table_name, + column_name, + cvs, + counts, + sample_count=self.SAMPLE_COUNT, + ), ] if counts_not_suppressed: generators += [ @@ -1050,12 +1222,16 @@ def __init__(self, value): super().__init__() self.value = value self.repr = repr(value) + def function_name(self) -> str: return "dist_gen.constant" + def nominal_kwargs(self) -> dict[str, str]: return {"value": self.repr} + def actual_kwargs(self) -> dict[str, any]: return {"value": self.value} + def generate_data(self, count) -> list[any]: return [self.value for _ in range(count)] @@ -1064,6 +1240,7 @@ class ConstantGeneratorFactory(GeneratorFactory): """ Just the null generator """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -1116,7 +1293,7 @@ def actual_kwargs(self) -> dict[str, any]: """ The kwargs (summary statistics) this 
generator is instantiated with. """ - return { "cov": self._covariates } + return {"cov": self._covariates} def generate_data(self, count) -> list[any]: """ @@ -1145,12 +1322,12 @@ def query( self, table: str, columns: list[Column], - predicates: list[str]=[], - group_by_clause: str="", - constant_clauses: str="", - constants: str="", - suppress_count: int=1, - sample_count: int | None=None, + predicates: list[str] = [], + group_by_clause: str = "", + constant_clauses: str = "", + constants: str = "", + suppress_count: int = 1, + sample_count: int | None = None, ) -> str: """ Gets a query for the basics for multivariate normal/lognormal parameters. @@ -1175,15 +1352,13 @@ def query( multiples = "".join( f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[:iy+1]) - ) - means = "".join( - f", _q.m{i}" for i in range(len(columns)) + for ix, colx in enumerate(columns[: iy + 1]) ) + means = "".join(f", _q.m{i}" for i in range(len(columns))) covs = "".join( f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" for iy in range(len(columns)) - for ix in range(iy+1) + for ix in range(iy + 1) ) if sample_count is None: subquery = table + where @@ -1211,21 +1386,21 @@ def get_generators(self, columns: list[Column], engine: Engine): query = self.query(table, columns) with engine.connect() as connection: try: - covariates = connection.execute(text( - query - )).mappings().first() + covariates = connection.execute(text(query)).mappings().first() except Exception as e: logger.debug("SQL query %s failed with error %s", query, e) return [] if not covariates or covariates["c0_0"] is None: return [] - return [MultivariateNormalGenerator( - table, - column_names, - query, - covariates, - self.function_name(), - )] + return [ + MultivariateNormalGenerator( + table, + column_names, + query, + covariates, + self.function_name(), + ) + ] class 
MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): @@ -1283,7 +1458,9 @@ def comment(self) -> str: else: where = f" where {text_list(self.excluded_columns.values())}" if len(self.included_numeric) == 1: - return f"Mean and variance for column {self.included_numeric[0].name}{where}." + return ( + f"Mean and variance for column {self.included_numeric[0].name}{where}." + ) return ( "Means and covariate matrix for the columns " f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" @@ -1298,7 +1475,7 @@ class NullPartitionedNormalGenerator(Generator): Generates data that matches the source data in missingness, choice of non-numeric data and numeric data. - + For the numeric data to be generated, samples of rows for each combination of non-numeric values and missingness. If any such combination has only one line in the source data (or sample of @@ -1307,15 +1484,16 @@ class NullPartitionedNormalGenerator(Generator): (although if the data is all non-numeric values and nulls, single rows are used because no covariate matrix is required for this). 
""" + def __init__( self, query_name: str, partitions: dict[int, RowPartition], - function_name: str="grouped_multivariate_lognormal", - name_suffix: str | None=None, - partition_count_query: str | None=None, - partition_counts: Sequence[RowMapping] | None=None, - partition_count_comment: str | None=None, + function_name: str = "grouped_multivariate_lognormal", + name_suffix: str | None = None, + partition_count_query: str | None = None, + partition_counts: Sequence[RowMapping] | None = None, + partition_count_comment: str | None = None, ): self._query_name = query_name self._partitions = partitions @@ -1358,7 +1536,7 @@ def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) "constants_at": partition.constant_outputs, "subgen": f'"{self._function_name}"', "params": covariates, - } + }, } def _count_query_name(self): @@ -1407,7 +1585,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): "name": self._function_name, "params": { "covs": partition.covariates, - } + }, } return { "count": count, @@ -1418,7 +1596,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): "params": { "covs": partition.covariates, }, - } + }, } def actual_kwargs(self) -> dict[str, any]: @@ -1438,10 +1616,7 @@ def generate_data(self, count) -> list[any]: Generate 'count' random data points for this column. 
""" kwargs = self.actual_kwargs() - return [ - dist_gen.alternatives(**kwargs) - for _ in range(count) - ] + return [dist_gen.alternatives(**kwargs) for _ in range(count)] def fit(self, default=None) -> float | None: return default @@ -1449,11 +1624,11 @@ def fit(self, default=None) -> float | None: def is_numeric(col: Column) -> bool: ct = get_column_type(col) - return ( - isinstance(ct, Numeric) or isinstance(ct, Integer) - ) and not col.foreign_keys + return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys + + +T = TypeVar("T") -T = TypeVar('T') def powerset(input: Iterable[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" @@ -1465,6 +1640,7 @@ class NullableColumn: """ A reference to a nullable column whose nullability is part of a partitioning. """ + column: Column # The bit (power of two) of the number of the partition in the partition sizes list bitmask: int @@ -1474,13 +1650,12 @@ class NullPatternPartition: """ The definition of a partition (in other words, what makes it not another partition) """ + def __init__( - self, - columns: Iterable[Column], - partition_nonnulls: Iterable[NullableColumn] + self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] ): self.index = sum(nc.bitmask for nc in partition_nonnulls) - nonnull_columns = { nc.column.name for nc in partition_nonnulls } + nonnull_columns = {nc.column.name for nc in partition_nonnulls} self.included_numeric: list[Column] = [] self.included_choice: dict[int, str] = {} self.group_by_clause = "" @@ -1535,17 +1710,21 @@ def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: out: list[NullableColumn] = [] for col in columns: if col.nullable: - out.append(NullableColumn( - column=col, - bitmask=2 ** len(out), - )) + out.append( + NullableColumn( + column=col, + bitmask=2 ** len(out), + ) + ) return out - def get_partition_count_query(self, ncs: list[NullableColumn], table: str, where: str | None=None) -> str: + 
def get_partition_count_query( + self, ncs: list[NullableColumn], table: str, where: str | None = None + ) -> str: """ Returns a SQL expression returning columns ``count`` and ``index``. - Each row returned represents one of the null pattern partitions. + Each row returned represents one of the null pattern partitions. ``index`` is the bitmask of all those nullable columns that are not null for this partition, and ``count`` is the total number of rows in this partition. """ @@ -1576,7 +1755,7 @@ def get_generators(self, columns: list[Column], engine: Engine): columns=partition_def.included_numeric, predicates=partition_def.predicates, group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, + constants=partition_def.constants, constant_clauses=partition_def.constant_clauses, ) row_partitions_maximal[partition_def.index] = RowPartition( @@ -1592,7 +1771,7 @@ def get_generators(self, columns: list[Column], engine: Engine): columns=partition_def.included_numeric, predicates=partition_def.predicates, group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, + constants=partition_def.constants, constant_clauses=partition_def.constant_clauses, suppress_count=self.SUPPRESS_COUNT, sample_count=self.SAMPLE_COUNT, @@ -1608,43 +1787,49 @@ def get_generators(self, columns: list[Column], engine: Engine): gens = [] try: with engine.connect() as connection: - partition_query_max = self.get_partition_count_query(nullable_columns, table) - partition_count_max_results = connection.execute( - text(partition_query_max) - ).mappings().fetchall() + partition_query_max = self.get_partition_count_query( + nullable_columns, table + ) + partition_count_max_results = ( + connection.execute(text(partition_query_max)).mappings().fetchall() + ) count_comment = f"Number of rows for each combination of the columns { {nc.column.name for nc in nullable_columns} } of the table {table} being null" if self._execute_partition_queries(connection, 
row_partitions_maximal): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - )) + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_maximal, + self.function_name(), + partition_count_query=partition_query_max, + partition_counts=partition_count_max_results, + partition_count_comment=count_comment, + ) + ) partition_query_ss = self.get_partition_count_query( nullable_columns, table, - where=f"WHERE {self.SUPPRESS_COUNT} < count" + where=f"WHERE {self.SUPPRESS_COUNT} < count", + ) + partition_count_ss_results = ( + connection.execute(text(partition_query_ss)).mappings().fetchall() ) - partition_count_ss_results = connection.execute( - text(partition_query_ss) - ).mappings().fetchall() if self._execute_partition_queries(connection, row_partitions_ss): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - )) + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_ss, + self.function_name(), + name_suffix="sampled and suppressed", + partition_count_query=partition_query_ss, + partition_counts=partition_count_ss_results, + partition_count_comment=count_comment, + ) + ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) return [] return gens - + def _execute_partition_queries( self, connection: Connection, @@ -1656,9 +1841,7 @@ def _execute_partition_queries( """ found_nonzero = False for rp in partitions.values(): - rp.covariates = connection.execute(text( - rp.query - )).mappings().fetchall() + rp.covariates = 
connection.execute(text(rp.query)).mappings().fetchall() if not rp.covariates or rp.covariates[0]["count"] is None: rp.covariates = [{"count": 0}] else: @@ -1682,19 +1865,21 @@ def query_var(self, column: str) -> str: @lru_cache(1) def everything_factory(): - return MultiGeneratorFactory([ - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(), - NullPartitionedLogNormalGeneratorFactory(), - ]) + return MultiGeneratorFactory( + [ + MimesisStringGeneratorFactory(), + MimesisIntegerGeneratorFactory(), + MimesisFloatGeneratorFactory(), + MimesisDateGeneratorFactory(), + MimesisDateTimeGeneratorFactory(), + MimesisTimeGeneratorFactory(), + ContinuousDistributionGeneratorFactory(), + ContinuousLogDistributionGeneratorFactory(), + ChoiceGeneratorFactory(), + ConstantGeneratorFactory(), + MultivariateNormalGeneratorFactory(), + MultivariateLogNormalGeneratorFactory(), + NullPartitionedNormalGeneratorFactory(), + NullPartitionedLogNormalGeneratorFactory(), + ] + ) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 5e3fb899..4f5e8c55 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -10,7 +10,7 @@ import sqlalchemy from prettytable import PrettyTable -from sqlalchemy import Column, MetaData, Table, text, ForeignKey +from sqlalchemy import Column, ForeignKey, MetaData, Table, text from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( @@ -26,15 +26,18 @@ # See https://github.com/pyreadline3/pyreadline3/issues/37 try: import readline + if not 
hasattr(readline, "backend"): readline.backend = "readline" except: pass + def or_default(v, d): - """ Returns v if it isn't None, otherwise d. """ + """Returns v if it isn't None, otherwise d.""" return d if v is None else v + class TableType(Enum): GENERATE = "generate" IGNORE = "ignore" @@ -42,6 +45,7 @@ class TableType(Enum): PRIVATE = "private" EMPTY = "empty" + TYPE_LETTER = { TableType.GENERATE: "G", TableType.IGNORE: "I", @@ -58,6 +62,7 @@ class TableType(Enum): TableType.EMPTY: "(table: {} (empty))", } + @dataclass class TableEntry: name: str # name of the table @@ -67,15 +72,19 @@ class AskSaveCmd(cmd.Cmd): intro = "Do you want to save this configuration?" prompt = "(yes/no/cancel) " file = None + def __init__(self): super().__init__() self.result = "" + def do_yes(self, _arg): self.result = "yes" return True + def do_no(self, _arg): self.result = "no" return True + def do_cancel(self, _arg): self.result = "cancel" return True @@ -98,7 +107,9 @@ class DbCmd(ABC, cmd.Cmd): def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: ... 
- def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__() self.config = config self.metadata = metadata @@ -115,30 +126,33 @@ def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Ma self.table_entries.append(entry) self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): self.engine.dispose() def print(self, text: str, *args, **kwargs): print(text.format(*args, **kwargs)) + def print_table(self, headings: list[str], rows: list[list[str]]): output = PrettyTable() output.field_names = headings for row in rows: output.add_row(row) print(output) + def print_table_by_columns(self, columns: dict[str, list[str]]): output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): output.add_column(field_name, data + [None] * (row_count - len(data))) print(output) + def print_results(self, result): - self.print_table( - list(result.keys()), - [list(row) for row in result.all()] - ) + self.print_table(list(result.keys()), [list(row) for row in result.all()]) + def ask_save(self): ask = AskSaveCmd() ask.cmdloop() @@ -150,39 +164,51 @@ def set_table_index(self, index) -> bool: self.set_prompt() return True return False + def next_table(self, report="No more tables"): if not self.set_table_index(self.table_index + 1): self.print(report) return False return True + def table_name(self): return self.table_entries[self.table_index].name + def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] + def get_column_names(self) -> list[str]: - return [ - col.name - for col in self.table_metadata().columns - ] + return [col.name for col in self.table_metadata().columns] + def report_columns(self): - self.print_table(["name", 
"type", "primary", "nullable", "foreign key"], [ - [name, str(col.type), col.primary_key, col.nullable, ", ".join( - [fk_column_name(fk) for fk in col.foreign_keys] - )] - for name, col in self.table_metadata().columns.items() - ]) + self.print_table( + ["name", "type", "primary", "nullable", "foreign key"], + [ + [ + name, + str(col.type), + col.primary_key, + col.nullable, + ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), + ] + for name, col in self.table_metadata().columns.items() + ], + ) + def get_table_config(self, table_name: str) -> dict[str, any]: ts = self.config.get("tables", None) if type(ts) is not dict: return {} t = ts.get(table_name) return t if type(t) is dict else {} + def set_table_config(self, table_name: str, config: dict[str, any]): ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} return ts[table_name] = config + def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: src_stats = self.config.get("src-stats", []) new_src_stats = [] @@ -191,6 +217,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: new_src_stats.append(stat) self.config["src-stats"] = new_src_stats return new_src_stats + def get_nonnull_columns(self, table_name: str): metadata_table = self.metadata.tables[table_name] return [ @@ -198,52 +225,59 @@ def get_nonnull_columns(self, table_name: str): for name, column in metadata_table.columns.items() if column.nullable ] + def find_entry_index_by_table_name(self, table_name) -> int | None: return next( - (i for i,entry in enumerate(self.table_entries) if entry.name == table_name), + ( + i + for i, entry in enumerate(self.table_entries) + if entry.name == table_name + ), None, ) + def find_entry_by_table_name(self, table_name) -> TableEntry | None: for e in self.table_entries: if e.name == table_name: return e return None + def do_counts(self, _arg): "Report the column names with the counts of nulls in them" if 
len(self.table_entries) <= self.table_index: return table_name = self.table_name() nonnull_columns = self.get_nonnull_columns(table_name) - colcounts = [ - ", COUNT({0}) AS {0}".format(nnc) - for nnc in nonnull_columns - ] + colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] with self.engine.connect() as connection: result = connection.execute( - text("SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( - table=table_name, - colcounts="".join(colcounts), - )) + text( + "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( + table=table_name, + colcounts="".join(colcounts), + ) + ) ).first() if result is None: self.print("Could not count rows in table {0}", table_name) return row_count = result.row_count self.print(self.ROW_COUNT_MSG, row_count) - self.print_table(["Column", "NULL count"], [ - [name, row_count - count] - for name, count in result._mapping.items() - if name != "row_count" - ]) + self.print_table( + ["Column", "NULL count"], + [ + [name, row_count - count] + for name, count in result._mapping.items() + if name != "row_count" + ], + ) def do_select(self, arg): "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 with self.engine.connect() as connection: try: - result = connection.execute( - text("SELECT " + arg) - ) + result = connection.execute(text("SELECT " + arg)) except sqlalchemy.exc.DatabaseError as exc: self.print("Failed to execute: {}", exc) return @@ -252,10 +286,7 @@ def do_select(self, arg): if 50 < row_count: self.print("Showing the first {} rows", MAX_SELECT_ROWS) fields = list(result.keys()) - rows = [ - row._tuple() - for row in result.fetchmany(MAX_SELECT_ROWS) - ] + rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)] self.print_table(fields, rows) def do_peek(self, arg: str): @@ -274,30 +305,25 @@ def do_peek(self, arg: str): nonnulls = [cn + " IS NOT NULL" for cn in col_names] with self.engine.connect() as connection: query = "SELECT 
{cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format( - cols=",".join(col_names), - table=table_name, - where="WHERE" if nonnulls else "", - nonnull=" OR ".join(nonnulls), - max=MAX_PEEK_ROWS, - ) + cols=",".join(col_names), + table=table_name, + where="WHERE" if nonnulls else "", + nonnull=" OR ".join(nonnulls), + max=MAX_PEEK_ROWS, + ) try: result = connection.execute(text(query)) except Exception as exc: self.print(f'SQL query "{query}" caused exception {exc}') return - rows = [ - row._tuple() - for row in result.fetchmany(MAX_PEEK_ROWS) - ] + rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] self.print_table(list(result.keys()), rows) def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int): if len(self.table_entries) <= self.table_index: return [] return [ - col - for col in self.table_metadata().columns.keys() - if col.startswith(text) + col for col in self.table_metadata().columns.keys() if col.startswith(text) ] @@ -306,6 +332,7 @@ class TableCmdTableEntry(TableEntry): old_type: TableType new_type: TableType + class TableCmd(DbCmd): intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n" doc_leader = """Use the commands 'ignore', 'vocabulary', @@ -316,10 +343,16 @@ class TableCmd(DbCmd): to exit this program.""" prompt = "(tableconf) " file = None - WARNING_TEXT_VOCAB_TO_NON_VOCAB = "Vocabulary table {0} references non-vocabulary table {1}" - WARNING_TEXT_NON_EMPTY_TO_EMPTY = "Empty table {1} referenced from non-empty table {0}. {1} will need stories." + WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( + "Vocabulary table {0} references non-vocabulary table {1}" + ) + WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( + "Empty table {1} referenced from non-empty table {0}. {1} will need stories." 
+ ) WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" - WARNING_TEXT_POTENTIAL_PROBLEMS = "NOTE: The following table types might cause problems later:" + WARNING_TEXT_POTENTIAL_PROBLEMS = ( + "NOTE: The following table types might cause problems later:" + ) NOTE_TEXT_NO_CHANGES = "You have made no changes." NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" @@ -334,7 +367,9 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: return TableCmdTableEntry(name, TableType.EMPTY, TableType.EMPTY) return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -344,16 +379,21 @@ def set_prompt(self): self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) else: self.prompt = "(table) " + def set_type(self, t_type: TableType): if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type + def _copy_entries(self) -> None: for entry in self.table_entries: entry: TableCmdTableEntry if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) - if entry.old_type == TableType.EMPTY and table.get("num_rows_per_pass", 1) == 0: + if ( + entry.old_type == TableType.EMPTY + and table.get("num_rows_per_pass", 1) == 0 + ): table["num_rows_per_pass"] = 1 if entry.new_type == TableType.IGNORE: table["ignore"] = True @@ -381,13 +421,11 @@ def _copy_entries(self) -> None: def _get_referenced_tables(self, from_table_name: str) -> set[str]: from_meta = self.metadata.tables[from_table_name] return { - fk.column.table.name - for col in from_meta.columns - for fk in col.foreign_keys + fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys } def _sanity_check_failures(self) -> 
list[tuple[str, str, str]]: - """ Find tables that reference each other that should not given their types. """ + """Find tables that reference each other that should not given their types.""" failures = [] for from_entry in self.table_entries: from_entry: TableCmdTableEntry @@ -396,16 +434,21 @@ def _sanity_check_failures(self) -> list[tuple[str, str, str]]: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type != TableType.VOCABULARY: - failures.append(( - self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - from_entry.name, - to_entry.name, - )) + if ( + to_entry is not None + and to_entry.new_type != TableType.VOCABULARY + ): + failures.append( + ( + self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + from_entry.name, + to_entry.name, + ) + ) return failures def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: - """ Find tables that reference each other that might cause problems given their types. """ + """Find tables that reference each other that might cause problems given their types.""" warnings = [] for from_entry in self.table_entries: from_entry: TableCmdTableEntry @@ -414,15 +457,19 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type in {TableType.EMPTY, TableType.IGNORE}: - warnings.append(( - self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - from_entry.name, - to_entry.name, - )) + if to_entry is not None and to_entry.new_type in { + TableType.EMPTY, + TableType.IGNORE, + }: + warnings.append( + ( + self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + from_entry.name, + to_entry.name, + ) + ) return warnings - def do_quit(self, _arg): "Check the updates, save them if desired and quit the configurer." 
count = 0 @@ -440,12 +487,12 @@ def do_quit(self, _arg): failures = self._sanity_check_failures() if failures: self.print(self.WARNING_TEXT_PROBLEMS_EXIST) - for (text, from_t, to_t) in failures: + for text, from_t, to_t in failures: self.print(text, from_t, to_t) warnings = self._sanity_check_warnings() if warnings: self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) - for (text, from_t, to_t) in warnings: + for text, from_t, to_t in warnings: self.print(text, from_t, to_t) reply = self.ask_save() if reply == "yes": @@ -454,6 +501,7 @@ def do_quit(self, _arg): if reply == "no": return True return False + def do_tables(self, _arg): "list the tables with their types" for entry in self.table_entries: @@ -461,6 +509,7 @@ def do_tables(self, _arg): new = entry.new_type becomes = " " if old == new else "->" + TYPE_LETTER[new] self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) + def do_next(self, arg): "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: @@ -472,44 +521,51 @@ def do_next(self, arg): self.set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) + def complete_next(self, text, line, begidx, endidx): return [ - entry.name - for entry in self.table_entries - if entry.name.startswith(text) + entry.name for entry in self.table_entries if entry.name.startswith(text) ] + def do_previous(self, _arg): "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) + def do_ignore(self, _arg): "Set the current table as ignored, and go to the next table" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() + def do_vocabulary(self, _arg): "Set the current table as a vocabulary table, and go to the next table" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", self.table_name()) self.next_table() + def do_private(self, _arg): "Set the current table as a primary private 
table (such as the table of patients)" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() + def do_generate(self, _arg): "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() + def do_empty(self, _arg): "Set the current table as empty; no generators will be run for it" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() + def do_columns(self, _arg): "Report the column names and metadata" self.report_columns() + def do_data(self, arg: str): """ Report some data. @@ -549,15 +605,13 @@ def do_data(self, arg: str): if number is None: number = 48 self.print_column_data(column, number, min_length) + def complete_data(self, text, line, begidx, endidx): - previous_parts = line[:begidx - 1].split() + previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] table_metadata = self.table_metadata() - return [ - k for k in table_metadata.columns.keys() - if k.startswith(text) - ] + return [k for k in table_metadata.columns.keys() if k.startswith(text)] def print_column_data(self, column: str, count: int, min_length: int): where = f"WHERE {column} IS NOT NULL" @@ -568,22 +622,26 @@ def print_column_data(self, column: str, count: int, min_length: int): ) with self.engine.connect() as connection: result = connection.execute( - text("SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - column=column, - count=count, - where=where, - )) + text( + "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( + table=self.table_name(), + column=column, + count=count, + where=where, + ) + ) ) self.columnize([str(x[0]) for x in result.all()]) def print_row_data(self, count: int): with self.engine.connect() as 
connection: result = connection.execute( - text("SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - count=count, - )) + text( + "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( + table=self.table_name(), + count=count, + ) + ) ) if result is None: self.print("No rows in this table!") @@ -591,7 +649,9 @@ def print_row_data(self, count: int): self.print_results(result) -def update_config_tables(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): +def update_config_tables( + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping +): with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() return tc.config @@ -599,8 +659,8 @@ def update_config_tables(src_dsn: str, src_schema: str, metadata: MetaData, conf @dataclass class MissingnessType: - SAMPLED="column_presence.sampled" - SAMPLED_QUERY=( + SAMPLED = "column_presence.sampled" + SAMPLED_QUERY = ( "SELECT COUNT(*) AS row_count, {result_names} FROM " "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})" " AS __t GROUP BY {result_names}" @@ -609,16 +669,13 @@ class MissingnessType: query: str comment: str columns: list[str] + @classmethod def sampled_query(cls, table, count, column_names) -> str: - result_names = ", ".join([ - "{0}__is_null".format(c) - for c in column_names - ]) - column_is_nulls = ", ".join([ - "{0} IS NULL AS {0}__is_null".format(c) - for c in column_names - ]) + result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) + column_is_nulls = ", ".join( + ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] + ) return cls.SAMPLED_QUERY.format( result_names=result_names, column_is_nulls=column_is_nulls, @@ -644,8 +701,10 @@ class MissingnessCmd(DbCmd): file = None PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]') - def find_missingness_query(self, missingness_generator: Mapping) -> tuple[str | None, str | None] | None: - """ Find query and comment from 
src-stats for the passed missingness generator. """ + def find_missingness_query( + self, missingness_generator: Mapping + ) -> tuple[str | None, str | None] | None: + """Find query and comment from src-stats for the passed missingness generator.""" kwargs = missingness_generator.get("kwargs", {}) patterns = kwargs.get("patterns", "") pattern_match = self.PATTERN_RE.match(patterns) @@ -655,6 +714,7 @@ def find_missingness_query(self, missingness_generator: Mapping) -> tuple[str | if src_stat.get("name") == key: return (src_stat.get("query", None), src_stat.get("comment", None)) return None + def make_table_entry(self, name: str, table: Mapping) -> TableEntry: if table.get("ignore", False): return None @@ -695,7 +755,9 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: new_type=old, ) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -709,10 +771,12 @@ def set_prompt(self): self.prompt = "(missingness for {0}: {1}) ".format(entry.name, nt.name) else: self.prompt = "(missingness) " + def set_type(self, t_type: TableType): if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type + def _copy_entries(self) -> None: src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: @@ -722,16 +786,26 @@ def _copy_entries(self) -> None: table.pop("missingness_generators", None) else: src_stat_key = "missing_auto__{0}__0".format(entry.name) - table["missingness_generators"] = [{ - "name": entry.new_type.name, - "kwargs": {"patterns": 'SRC_STATS["{0}"]["results"]'.format(src_stat_key)}, - "columns": entry.new_type.columns, - }] - src_stats.append({ - "name": src_stat_key, - "query": entry.new_type.query, - "comments": [] if entry.new_type.comment is None else 
[entry.new_type.comment], - }) + table["missingness_generators"] = [ + { + "name": entry.new_type.name, + "kwargs": { + "patterns": 'SRC_STATS["{0}"]["results"]'.format( + src_stat_key + ) + }, + "columns": entry.new_type.columns, + } + ] + src_stats.append( + { + "name": src_stat_key, + "query": entry.new_type.query, + "comments": [] + if entry.new_type.comment is None + else [entry.new_type.comment], + } + ) self.set_table_config(entry.name, table) def do_quit(self, _arg): @@ -741,9 +815,17 @@ def do_quit(self, _arg): if entry.old_type != entry.new_type: count += 1 if entry.old_type is None: - self.print("Putting generator {0} on table {1}", entry.name, entry.new_type.name) + self.print( + "Putting generator {0} on table {1}", + entry.name, + entry.new_type.name, + ) elif entry.new_type is None: - self.print("Deleting generator {1} from table {0}", entry.name, entry.old_type.name) + self.print( + "Deleting generator {1} from table {0}", + entry.name, + entry.old_type.name, + ) else: self.print( "Changing {0} from {1} to {2}", @@ -760,6 +842,7 @@ def do_quit(self, _arg): if reply == "no": return True return False + def do_tables(self, arg): "list the tables with their types" for entry in self.table_entries: @@ -767,27 +850,32 @@ def do_tables(self, arg): new = "-" if entry.new_type is None else entry.new_type.name desc = new if old == new else "{0}->{1}".format(old, new) self.print("{0} {1}", entry.name, desc) + def do_next(self, arg): "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any - index = next((i for i,entry in enumerate(self.table_entries) if entry.name == arg), None) + index = next( + (i for i, entry in enumerate(self.table_entries) if entry.name == arg), + None, + ) if index is None: self.print(self.ERROR_NO_SUCH_TABLE, arg) return self.set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) + def complete_next(self, text, line, begidx, endidx): return [ - 
entry.name - for entry in self.table_entries - if entry.name.startswith(text) + entry.name for entry in self.table_entries if entry.name.startswith(text) ] + def do_previous(self, _arg): "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) + def _set_type(self, name, query, comment): if len(self.table_entries) <= self.table_index: return @@ -798,11 +886,13 @@ def _set_type(self, name, query, comment): comment=comment, columns=self.get_nonnull_columns(entry.name), ) + def _set_none(self): if len(self.table_entries) <= self.table_index: return entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] entry.new_type = None + def do_sampled(self, arg: str): """ Set the current table missingness as 'sampled', and go to the next table. @@ -819,7 +909,10 @@ def do_sampled(self, arg: str): elif arg.isdecimal(): count = int(arg) else: - self.print("Error: sampled can be used alone or with an integer argument. {0} is not permitted", arg) + self.print( + "Error: sampled can be used alone or with an integer argument. 
{0} is not permitted", + arg, + ) return self._set_type( MissingnessType.SAMPLED, @@ -828,10 +921,11 @@ def do_sampled(self, arg: str): count, self.get_nonnull_columns(entry.name), ), - f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}" + f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}", ) self.print("Table {} set to sampled missingness", self.table_name()) self.next_table() + def do_none(self, _arg): "Set the current table to have no missingness, and go to the next table" self._set_none() @@ -839,7 +933,9 @@ def do_none(self, _arg): self.next_table() -def update_missingness(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): +def update_missingness( + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping +): with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() return mc.config @@ -850,11 +946,13 @@ class GeneratorInfo: columns: list[str] gen: Generator | None + @dataclass class GeneratorCmdTableEntry(TableEntry): old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] + class GeneratorCmd(DbCmd): intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the @@ -883,7 +981,9 @@ class GeneratorCmd(DbCmd): ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" PROPOSE_NOTHING = "No proposed generators, sorry." - SRC_STAT_RE = re.compile(r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?') + SRC_STAT_RE = re.compile( + r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
+ ) def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None: if table.get("ignore", False): @@ -902,35 +1002,47 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None gen_name = rg.get("name", None) if gen_name: ca = rg.get("columns_assigned", []) - collist: list[str] = [ca] if isinstance(ca, str) else [str(c) for c in ca] + collist: list[str] = ( + [ca] if isinstance(ca, str) else [str(c) for c in ca] + ) colset: set[str] = set(collist) for unknown in colset - column_set: logger.warning( "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, gen_name, unknown + table_name, + gen_name, + unknown, ) for mult in columns_assigned_so_far & colset: logger.warning( - "table '%s' has column '%s' assigned to multiple times", table_name, mult + "table '%s' has column '%s' assigned to multiple times", + table_name, + mult, ) actual_collist = [c for c in collist if c in columns] if actual_collist: gen = PredefinedGenerator(table, rg, self.config) - new_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) - old_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) + new_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + old_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) columns_assigned_so_far |= colset for colname in columns: if colname not in columns_assigned_so_far: - new_generator_infos.append(GeneratorInfo( - columns=[colname], - gen=None, - )) + new_generator_infos.append( + GeneratorInfo( + columns=[colname], + gen=None, + ) + ) if len(new_generator_infos) == 0: return None return GeneratorCmdTableEntry( @@ -939,7 +1051,9 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None new_generators=new_generator_infos, ) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): 
+ def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 self.generators_valid_columns = None @@ -957,7 +1071,10 @@ def previous_table(self): if ret: table = self.get_table() if table is None: - self.print("Internal error! table {0} does not have any generators!", self.table_index) + self.print( + "Internal error! table {0} does not have any generators!", + self.table_index, + ) return False self.generator_index = len(table.new_generators) - 1 else: @@ -985,10 +1102,7 @@ def column_metadata(self) -> list[Column]: table = self.table_metadata() if table is None: return [] - return [ - table.columns[name] - for name in self.get_column_names() - ] + return [table.columns[name] for name in self.get_column_names()] def set_prompt(self): (table_name, gen_info) = self.get_table_and_generator() @@ -1000,8 +1114,7 @@ def set_prompt(self): return table = self.table_metadata() columns = [ - c + "[pk]" if table.columns[c].primary_key else c - for c in gen_info.columns + c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns ] gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " @@ -1020,11 +1133,15 @@ def _copy_entries(self) -> None: new_gens.append(generator.gen) cqs = generator.gen.custom_queries() for cq_key, cq in cqs.items(): - src_stats.append({ - "name": cq_key, - "query": cq["query"], - "comments": [cq["comment"]] if "comment" in cq and cq["comment"] else [], - }) + src_stats.append( + { + "name": cq_key, + "query": cq["query"], + "comments": [cq["comment"]] + if "comment" in cq and cq["comment"] + else [], + } + ) rg = { "name": generator.gen.function_name(), "columns_assigned": generator.columns, @@ -1035,16 +1152,18 @@ def _copy_entries(self) -> None: rgs.append(rg) aq = self._get_aggregate_query(new_gens, entry.name) if aq: - src_stats.append({ - "name": 
f"auto__{entry.name}", - "query": aq, - "comments": [ - q["comment"] - for gen in new_gens - for q in gen.select_aggregate_clauses().values() - if "comment" in q and q["comment"] is not None - ], - }) + src_stats.append( + { + "name": f"auto__{entry.name}", + "query": aq, + "comments": [ + q["comment"] + for gen in new_gens + for q in gen.select_aggregate_clauses().values() + if "comment" in q and q["comment"] is not None + ], + } + ) table_config = self.get_table_config(entry.name) if rgs: table_config["row_generators"] = rgs @@ -1053,8 +1172,10 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table_config) self.config["src-stats"] = src_stats - def _find_old_generator(self, entry: GeneratorCmdTableEntry, columns) -> Generator | None: - """ Find any generator that previously assigned to these exact same columns. """ + def _find_old_generator( + self, entry: GeneratorCmdTableEntry, columns + ) -> Generator | None: + """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) for gen in entry.old_generators: if frozenset(gen.columns) == fc: @@ -1139,12 +1260,18 @@ def do_info(self, _arg): "nullable" if cm.nullable else "not nullable", ) if cm.primary_key: - self.print("It is a primary key, which usually does not need a generator (it will auto-increment)") + self.print( + "It is a primary key, which usually does not need a generator (it will auto-increment)" + ) if cm.foreign_keys: fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] - self.print("It is a foreign key referencing column {0}", ", ".join(fk_names)) + self.print( + "It is a foreign key referencing column {0}", ", ".join(fk_names) + ) if len(fk_names) == 1 and not cm.primary_key: - self.print("You do not need a generator if you just want a uniform choice over the referenced table's rows") + self.print( + "You do not need a generator if you just want a uniform choice over the referenced table's rows" + ) def _get_table_index(self, 
table_name: str) -> int | None: for n, entry in enumerate(self.table_entries): @@ -1196,7 +1323,7 @@ def do_next(self, arg): self._go_next() def do_n(self, arg): - """ Synonym for next """ + """Synonym for next""" self.do_next(arg) def complete_n(self, text: str, line: str, begidx: int, endidx: int): @@ -1248,7 +1375,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): return table_names + column_names def do_previous(self, _arg): - """ Go to the previous generator """ + """Go to the previous generator""" if self.generator_index == 0: self.previous_table() else: @@ -1256,11 +1383,14 @@ def do_previous(self, _arg): self.set_prompt() def do_b(self, arg): - """ Synonym for previous """ + """Synonym for previous""" self.do_previous(arg) def _generators_valid(self) -> bool: - return self.generators_valid_columns == (self.table_index, self.get_column_names()) + return self.generators_valid_columns == ( + self.table_index, + self.get_column_names(), + ) def _get_generator_proposals(self) -> list[Generator]: if not self._generators_valid(): @@ -1270,7 +1400,10 @@ def _get_generator_proposals(self) -> list[Generator]: gens = everything_factory().get_generators(columns, self.engine) gens.sort(key=lambda g: g.fit(9999)) self.generators = gens - self.generators_valid_columns = (self.table_index, self.get_column_names().copy()) + self.generators_valid_columns = ( + self.table_index, + self.get_column_names().copy(), + ) return self.generators def _print_privacy(self): @@ -1316,7 +1449,7 @@ def do_compare(self, arg: str): self.print_table_by_columns(comparison) def do_c(self, arg): - """ Synonym for compare. 
""" + """Synonym for compare.""" self.do_compare(arg) def _print_values_queried(self, table_name: str, n: int, gen: Generator): @@ -1354,7 +1487,11 @@ def _print_custom_queries(self, gen: Generator) -> None: actual, ) for cq_key, cq in cqs.items(): - self.print("{0}; providing the following values: {1}", cq["query"], cq_key2args[cq_key]) + self.print( + "{0}; providing the following values: {1}", + cq["query"], + cq_key2args[cq_key], + ) def _get_custom_queries_from(self, out, nominal, actual): if type(nominal) is str: @@ -1375,7 +1512,9 @@ def _get_custom_queries_from(self, out, nominal, actual): if k in actual: self._get_custom_queries_from(out, v, actual[k]) - def _get_aggregate_query(self, gens: list[Generator], table_name: str) -> str | None: + def _get_aggregate_query( + self, gens: list[Generator], table_name: str + ) -> str | None: clauses = [ f'{q["clause"]} AS {n}' for gen in gens @@ -1394,7 +1533,7 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: return kwa = gen.actual_kwargs() vals = [] - src_stat2kwarg = { v: k for k, v in gen.nominal_kwargs().items() } + src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} for n in sacs.keys(): src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' if src_stat in src_stat2kwarg: @@ -1402,9 +1541,16 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: if ak in kwa: vals.append(kwa[ak]) else: - logger.warning("actual_kwargs for %s does not report %s", gen.name(), ak) + logger.warning( + "actual_kwargs for %s does not report %s", gen.name(), ak + ) else: - logger.warning('nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', gen.name(), table_name, n) + logger.warning( + 'nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', + gen.name(), + table_name, + n, + ) select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; providing the following values: {1}", 
select_q, vals) @@ -1414,12 +1560,11 @@ def _get_column_data(self, count: int, to_str=repr): pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) with self.engine.connect() as connection: result = connection.execute( - text(f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}") + text( + f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" + ) ) - return [ - [to_str(x) for x in xs] - for xs in result.all() - ] + return [[to_str(x) for x in xs] for xs in result.all()] def do_propose(self, _arg): """ @@ -1433,10 +1578,7 @@ def do_propose(self, _arg): gens = self._get_generator_proposals() sample = self._get_column_data(limit) if sample: - rep = [ - x[0] if len(x) == 1 else ",".join(x) - for x in sample - ] + rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) else: self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) @@ -1455,11 +1597,11 @@ def do_propose(self, _arg): index=index + 1, name=gen.name(), fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))) + sample="; ".join(map(repr, gen.generate_data(limit))), ) def do_p(self, arg): - """ Synonym for propose """ + """Synonym for propose""" self.do_propose(arg) def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: @@ -1508,7 +1650,7 @@ def set_generator(self, gen: Generator): gen_info.gen = gen def do_s(self, arg): - """ Synonym for set """ + """Synonym for set""" self.do_set(arg) def do_unset(self, _arg): @@ -1519,7 +1661,7 @@ def do_unset(self, _arg): self._go_next() def do_merge(self, arg: str): - """ Add this column(s) to the specified column(s), so one generator covers them all. 
""" + """Add this column(s) to the specified column(s), so one generator covers them all.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1527,10 +1669,10 @@ def do_merge(self, arg: str): if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return - cols_available = functools.reduce(lambda x, y: x | y, [ - frozenset(gen.columns) - for gen in table_entry.new_generators - ]) + cols_available = functools.reduce( + lambda x, y: x | y, + [frozenset(gen.columns) for gen in table_entry.new_generators], + ) cols_to_merge = frozenset(cols) unknown_cols = cols_to_merge - cols_available if unknown_cols: @@ -1585,7 +1727,7 @@ def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): ] def do_unmerge(self, arg: str): - """ Remove this column(s) from this generator, make them a separate generator. """ + """Remove this column(s) from this generator, make them a separate generator.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1615,10 +1757,13 @@ def do_unmerge(self, arg: str): # The existing generator will not work gen_info.gen = None # And put them into a new (empty) generator - table_entry.new_generators.insert(self.generator_index + 1, GeneratorInfo( - columns=cols, - gen=None, - )) + table_entry.new_generators.insert( + self.generator_index + 1, + GeneratorInfo( + columns=cols, + gen=None, + ), + ) self.set_prompt() def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int): @@ -1650,7 +1795,11 @@ def update_config_generators( line_no += 1 if line: if len(line) != 3: - logger.error("line {0} of file {1} does not have three values", line_no, spec_path) + logger.error( + "line {0} of file {1} does not have three values", + line_no, + spec_path, + ) if gc.go_to(f"{line[0]}.{line[1]}"): gc.do_set(line[2]) gc.do_quit("yes") diff --git a/datafaker/main.py b/datafaker/main.py index 5b79831e..c22f0979 100644 --- a/datafaker/main.py +++ 
b/datafaker/main.py @@ -1,8 +1,8 @@ """Entrypoint for the datafaker package.""" import asyncio -from enum import Enum import json import sys +from enum import Enum from importlib import metadata from pathlib import Path from typing import Final, Optional @@ -15,8 +15,8 @@ from datafaker.create import create_db_data, create_db_tables, create_db_vocab from datafaker.dump import dump_db_tables from datafaker.interactive import ( - update_config_tables, update_config_generators, + update_config_tables, update_missingness, ) from datafaker.make import ( @@ -68,24 +68,24 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name, config: dict | None=None): +def load_metadata_config(orm_file_name, config: dict | None = None): with open(orm_file_name) as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) tables_dict = meta_dict.get("tables", {}) if config is not None and "tables" in config: # Remove ignored tables - for (name, table_config) in config.get("tables", {}).items(): + for name, table_config in config.get("tables", {}).items(): if get_flag(table_config, "ignore"): tables_dict.pop(name, None) return meta_dict -def load_metadata(orm_file_name, config: dict | None=None): +def load_metadata(orm_file_name, config: dict | None = None): meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) -def load_metadata_for_output(orm_file_name, config: dict | None=None): +def load_metadata_for_output(orm_file_name, config: dict | None = None): """ Load metadata excluding any foreign keys pointing to ignored tables. """ @@ -94,12 +94,9 @@ def load_metadata_for_output(orm_file_name, config: dict | None=None): @app.callback() -def main(verbose: bool = Option( - False, - "--verbose", - "-v", - help="Print more information." 
-)): +def main( + verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") +): conf_logger(verbose) @@ -108,7 +105,7 @@ def create_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option( DF_FILENAME, - help="The name of the generators file. Must be in the current working directory." + help="The name of the generators file. Must be in the current working directory.", ), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), num_passes: int = Option(1, help="Number of passes (rows or stories) to make"), @@ -145,7 +142,9 @@ def create_data( num_passes, ) logger.debug( - "Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes" + "Data created in %s %s.", + num_passes, + "pass" if num_passes == 1 else "passes", ) for table_name, row_count in row_counts.items(): logger.debug( @@ -210,9 +209,11 @@ def create_generators( "Statistics file (output of make-stats); default is src-stats.yaml if the " "config file references SRC_STATS, or None otherwise." ), - show_default=False + show_default=False, + ), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing Python generators file." ), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing Python generators file."), ) -> None: """Make a datafaker file of generator classes. 
@@ -249,7 +250,12 @@ def create_generators( def make_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - force: bool = Option(False, "--force/--no-force", "-f/+f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, + "--force/--no-force", + "-f/+f", + help="Overwrite any existing vocabulary file.", + ), compress: bool = Option(False, help="Compress file to .gz"), only: list[str] = Option([], help="Only download this table."), ) -> None: @@ -279,7 +285,9 @@ def make_stats( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: str = Option(STATS_FILENAME), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing vocabulary file." + ), ) -> None: """Compute summary statistics from the source database. @@ -309,9 +317,14 @@ def make_stats( @app.command() def make_tables( - config_file: Optional[str] = Option(None, help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables"), + config_file: Optional[str] = Option( + None, + help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables", + ), orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing orm yaml file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing orm yaml file." + ), ) -> None: """Make a YAML file representing the tables in the schema. 
@@ -335,7 +348,9 @@ def make_tables( @app.command() def configure_tables( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ): """ @@ -347,10 +362,14 @@ def configure_tables( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) # we don't pass config here so that no tables are ignored metadata = load_metadata(orm_file) - config_updated = update_config_tables(src_dsn, settings.src_schema, metadata, config) + config_updated = update_config_tables( + src_dsn, settings.src_schema, metadata, config + ) if config_updated is None: logger.debug("Cancelled") return @@ -361,7 +380,9 @@ def configure_tables( @app.command() def configure_missing( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ): """ @@ -373,7 +394,9 @@ def configure_missing( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) if config_updated is None: @@ -386,9 +409,14 @@ def configure_missing( @app.command() def configure_generators( - config_file: Optional[str] = Option(CONFIG_FILENAME, 
help="Path of the configuration file to alter"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - spec: Path = Option(None, help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively") + spec: Path = Option( + None, + help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", + ), ): """ Interactively set generators for column data. @@ -399,9 +427,13 @@ def configure_generators( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) metadata = load_metadata(orm_file, config) - config_updated = update_config_generators(src_dsn, settings.src_schema, metadata, config, spec_path=spec) + config_updated = update_config_generators( + src_dsn, settings.src_schema, metadata, config, spec_path=spec + ) if config_updated is None: logger.debug("Cancelled") return @@ -412,12 +444,14 @@ def configure_generators( @app.command() def dump_data( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path of the configuration file to alter"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: str = Argument(help="The table to dump"), output: str | None = Option(None, help="output CSV file name"), ): - """ Dump a whole table as a CSV file (or to the console) from the destination database. """ + """Dump a whole table as a CSV file (or to the console) from the destination database.""" settings = get_settings() dst_dsn: str = settings.dst_dsn or "" assert dst_dsn != "", "Missing DST_DSN setting." 
@@ -427,7 +461,7 @@ def dump_data( if output == None: dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return - with open(output, 'wt', newline='') as out: + with open(output, "wt", newline="") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) @@ -452,7 +486,9 @@ def validate_config( def remove_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate non-vocabulary tables in the destination schema.""" if yes: @@ -469,7 +505,9 @@ def remove_data( def remove_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate vocabulary tables in the destination schema.""" if yes: @@ -487,7 +525,9 @@ def remove_vocab( def remove_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Drop all tables in the destination schema. 
diff --git a/datafaker/make.py b/datafaker/make.py index 1284672f..17e0d3b8 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -5,13 +5,11 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import ( - Any, Callable, Final, Mapping, Optional, Sequence, Tuple -) -import yaml +from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple import pandas as pd import snsql +import yaml from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider @@ -26,8 +24,8 @@ from datafaker.utils import ( create_db_engine, download_table, - get_property, get_flag, + get_property, get_related_table_names, get_sync_engine, get_vocabulary_table_names, @@ -73,7 +71,8 @@ class RowGeneratorInfo: @dataclass class ColumnChoice: - """ Chooses columns based on a random number in [0,1) """ + """Chooses columns based on a random number in [0,1)""" + function_name: str argument_values: list[str] @@ -84,10 +83,7 @@ def make_column_choices( return [ ColumnChoice( function_name=mg["name"], - argument_values=[ - f"{k}={v}" - for k, v in mg.get("kwargs", {}).items() - ] + argument_values=[f"{k}={v}" for k, v in mg.get("kwargs", {}).items()], ) for mg in table_config.get("missingness_generators", []) if "name" in mg @@ -122,7 +118,9 @@ def _render_value(v) -> str: if type(v) is set: return "{" + ", ".join(_render_value(x) for x in v) + "}" if type(v) is dict: - return "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" + return ( + "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" + ) if type(v) is str: return v return str(v) @@ -181,9 +179,7 @@ def _get_row_generator( return row_gen_info, columns_covered -def _get_default_generator( - column: Column -) -> RowGeneratorInfo: +def _get_default_generator(column: Column) -> RowGeneratorInfo: """Get default generator information, for the 
given column.""" # If it's a primary key column, we presume that primary keys are populated # automatically. @@ -215,7 +211,8 @@ def _get_default_generator( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, positional_arguments=generator_arguments + function_name=generator_function, + positional_arguments=generator_arguments, ), ) @@ -243,10 +240,13 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: column_type = column.type if column_type.scale is None: return ("generic.numeric.float_number", {}) - return ("generic.numeric.float_number", { - "start": 0, - "end": 10 ** column_type.scale - 1, - }) + return ( + "generic.numeric.float_number", + { + "start": 0, + "end": 10**column_type.scale - 1, + }, + ) def _string_generator(column: Column) -> tuple[str, dict[str, str]]: @@ -257,7 +257,8 @@ def _string_generator(column: Column) -> tuple[str, dict[str, str]]: column_size: Optional[int] = getattr(column.type, "length", None) if column_size is None: return ("generic.text.color", {}) - return ("generic.person.password", { "length": str(column_size) }) + return ("generic.person.password", {"length": str(column_size)}) + def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ @@ -265,10 +266,13 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ if not column.primary_key: return ("generic.numeric.integer_number", {}) - return ("generic.column_value_provider.increment", { - "db_connection": "dst_db_conn", - "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', - }) + return ( + "generic.column_value_provider.increment", + { + "db_connection": "dst_db_conn", + "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', + }, + ) _YEAR_SUMMARY_QUERY = ( @@ -316,12 +320,12 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.Date: GeneratorInfo( 
generator="generic.datetime.date", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.DateTime: GeneratorInfo( generator="generic.datetime.datetime", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.Integer: GeneratorInfo( # must be before Numeric generator=_integer_generator, @@ -345,7 +349,7 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.String: GeneratorInfo( generator=_string_generator, choice=True, - ) + ), } @@ -358,7 +362,7 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: callable, dict of keyword arguments to pass to the callable). """ if column_t in _COLUMN_TYPE_TO_GENERATOR_INFO: - return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] + return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] # Search exhaustively for a superclass to the columns actual type for key, value in _COLUMN_TYPE_TO_GENERATOR_INFO.items(): @@ -368,8 +372,9 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: return None -def _get_generator_for_column(column_t: type) -> str | Callable[ - [type_api.TypeEngine], tuple[str, dict[str, str]]]: +def _get_generator_for_column( + column_t: type, +) -> str | Callable[[type_api.TypeEngine], tuple[str, dict[str, str]]]: """ Gets a generator from a column type. @@ -392,7 +397,7 @@ def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: generator_arguments: dict[str, str] = {} if callable(generator_function): (generator_function, generator_arguments) = generator_function(column) - return generator_function,generator_arguments + return generator_function, generator_arguments def _get_provider_for_column(column: Column) -> Tuple[list[str], str, dict[str, str]]: @@ -443,6 +448,7 @@ class _PrimaryConstraint: columns in a table comprise the primary key. Not a real constraint, but enough to write df.py. 
""" + def __init__(self, *columns: Column, name: str): self.name = name self.columns = columns @@ -461,15 +467,11 @@ def _get_generator_for_table( ), key=_constraint_sort_key, ) - primary_keys = [ - c for c in table.columns - if c.primary_key - ] + primary_keys = [c for c in table.columns if c.primary_key] if 1 < len(primary_keys): - unique_constraints.append(_PrimaryConstraint( - *primary_keys, - name=f"{table.name}_primary_key" - )) + unique_constraints.append( + _PrimaryConstraint(*primary_keys, name=f"{table.name}_primary_key") + ) column_choices = make_column_choices(table_config) if column_choices: nonnull_columns = { @@ -522,7 +524,7 @@ def make_vocabulary_tables( config: Mapping, overwrite_files: bool, compress: bool, - table_names: set[str] | None=None, + table_names: set[str] | None = None, ): """ Extracts the data from the source database for each @@ -539,7 +541,10 @@ def make_vocabulary_tables( else: invalid_names = table_names - vocab_names if invalid_names: - logger.error("The following names are not the names of vocabulary tables: %s", invalid_names) + logger.error( + "The following names are not the names of vocabulary tables: %s", + invalid_names, + ) logger.info("Valid names are: %s", vocab_names) return for table_name in table_names: @@ -584,7 +589,7 @@ def make_table_generators( # pylint: disable=too-many-locals tables: list[TableGeneratorInfo] = [] vocabulary_tables: list[VocabularyTableGeneratorInfo] = [] vocab_names = get_vocabulary_table_names(config) - for (table_name, table) in metadata.tables.items(): + for table_name, table in metadata.tables.items(): if table_name in vocab_names: related = get_related_table_names(table) related_non_vocab = related.difference(vocab_names) @@ -593,16 +598,18 @@ def make_table_generators( # pylint: disable=too-many-locals "Making table '%s' a vocabulary table requires that also the" " related tables (%s) be also vocabulary tables.", table.name, - related_non_vocab + related_non_vocab, ) 
vocabulary_tables.append( _get_generator_for_existing_vocabulary_table(table) ) else: - tables.append(_get_generator_for_table( - tables_config.get(table.name, {}), - table, - )) + tables.append( + _get_generator_for_table( + tables_config.get(table.name, {}), + table, + ) + ) story_generators = _get_story_generators(config) @@ -766,9 +773,7 @@ def fix_type(value): def fix_types(dics): - return [{ - k: fix_type(v) for k, v in dic.items() - } for dic in dics] + return [{k: fix_type(v) for k, v in dic.items()} for dic in dics] async def make_src_stats( @@ -793,7 +798,10 @@ async def make_src_stats( async with DbConnection(engine) as db_conn: return await make_src_stats_connection(config, db_conn, metadata) -async def make_src_stats_connection(config: Mapping, db_conn: DbConnection, metadata: MetaData): + +async def make_src_stats_connection( + config: Mapping, db_conn: DbConnection, metadata: MetaData +): date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/providers.py b/datafaker/providers.py index b07f2b7f..1ebd5bf5 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -5,8 +5,8 @@ from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider -from sqlalchemy import Connection, Column -from sqlalchemy.sql import functions, select, func +from sqlalchemy import Column, Connection +from sqlalchemy.sql import func, functions, select class ColumnValueProvider(BaseProvider): @@ -29,12 +29,12 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed = None, **kwargs): + def __init__(self, *, seed=None, **kwargs): super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} def increment(self, db_connection: Connection, column: Column) -> int: - """ Return incrementing value for the column specified. 
""" + """Return incrementing value for the column specified.""" name = f"{column.table.name}.{column.name}" result = self.accumulators.get(name, None) if result == None: diff --git a/datafaker/remove.py b/datafaker/remove.py index c0a6c47f..a316619e 100644 --- a/datafaker/remove.py +++ b/datafaker/remove.py @@ -1,7 +1,7 @@ """Functions and classes to undo the operations in create.py.""" from typing import Any, Mapping -from sqlalchemy import delete, MetaData +from sqlalchemy import MetaData, delete from datafaker.settings import get_settings from datafaker.utils import ( @@ -9,23 +9,18 @@ get_sync_engine, get_vocabulary_table_names, logger, - remove_vocab_foreign_key_constraints, reinstate_vocab_foreign_key_constraints, + remove_vocab_foreign_key_constraints, sorted_non_vocabulary_tables, ) -def remove_db_data( - metadata: MetaData, config: Mapping[str, Any] -) -> None: +def remove_db_data(metadata: MetaData, config: Mapping[str, Any]) -> None: """Truncate the synthetic data tables but not the vocabularies.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" remove_db_data_from( - metadata, - config, - settings.dst_dsn, - schema_name=settings.dst_schema + metadata, config, settings.dst_dsn, schema_name=settings.dst_schema ) @@ -33,9 +28,7 @@ def remove_db_data_from( metadata: MetaData, config: Mapping[str, Any], db_dsn: str, schema_name: str | None ) -> None: """Truncate the synthetic data tables but not the vocabularies.""" - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) with dst_engine.connect() as dst_conn: for table in reversed(sorted_non_vocabulary_tables(metadata, config)): @@ -44,7 +37,9 @@ def remove_db_data_from( dst_conn.commit() -def remove_db_vocab(metadata: MetaData, meta_dict: Mapping[str, Any], config: Mapping[str, Any]) -> None: +def remove_db_vocab( + metadata: MetaData, meta_dict: 
Mapping[str, Any], config: Mapping[str, Any] +) -> None: """Truncate the vocabulary tables.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 51f4b038..303c2c76 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,14 +1,16 @@ +from typing import Callable + import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql -from sqlalchemy.sql import sqltypes, schema -from typing import Callable +from sqlalchemy.sql import schema, sqltypes from datafaker.utils import make_foreign_key_name table_component_t = dict[str, any] table_t = dict[str, table_component_t] + def simple(type_): """ Parses a simple sqltypes type. @@ -17,29 +19,34 @@ def simple(type_): """ return parsy.string(type_.__name__).result(type_) + def integer(): """ Parses an integer, outputting that integer. """ return parsy.regex(r"-?[0-9]+").map(int) + def integer_arguments(): """ Parses a list of integers. The integers are surrounded by brackets and separated by a comma and space. """ - return parsy.string("(") >> ( - integer().sep_by(parsy.string(", ")) - ) << parsy.string(")") + return ( + parsy.string("(") >> (integer().sep_by(parsy.string(", "))) << parsy.string(")") + ) + def numeric_type(type_): """ Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. 
""" - return parsy.string(type_.__name__ - ) >> integer_arguments().optional([]).combine(type_) + return parsy.string(type_.__name__) >> integer_arguments().optional([]).combine( + type_ + ) + def string_type(type_): @parsy.generate(type_.__name__) @@ -56,8 +63,10 @@ def st_parser(): parsy.string(' COLLATE "') >> parsy.regex(r'[^"]*') << parsy.string('"') ).optional() return type_(length=length, collation=collation) + return st_parser + def time_type(type_, pg_type): @parsy.generate(type_.__name__) def pgt_parser(): @@ -70,18 +79,22 @@ def pgt_parser(): parsy.string("(") >> integer() << parsy.string(")") ).optional() timezone: str | None = yield ( - parsy.string(" WITH") >> ( - parsy.string(" ").result(True) | parsy.string("OUT ").result(False) - ) << parsy.string("TIME ZONE") + parsy.string(" WITH") + >> (parsy.string(" ").result(True) | parsy.string("OUT ").result(False)) + << parsy.string("TIME ZONE") ).optional(False) if precision is None and not timezone: # normal sql type return type_ return pg_type(precision=precision, timezone=timezone) + return pgt_parser + SIMPLE_TYPE_PARSER = parsy.alt( - parsy.string("DOUBLE PRECISION").result(sqltypes.DOUBLE_PRECISION), # must be before DOUBLE + parsy.string("DOUBLE PRECISION").result( + sqltypes.DOUBLE_PRECISION + ), # must be before DOUBLE simple(sqltypes.FLOAT), simple(sqltypes.DOUBLE), simple(sqltypes.INTEGER), @@ -110,6 +123,7 @@ def pgt_parser(): time_type(sqltypes.TIME, postgresql.types.TIME), ) + @parsy.generate def type_parser(): base = yield SIMPLE_TYPE_PARSER @@ -118,6 +132,7 @@ def type_parser(): return base return postgresql.ARRAY(base, dimensions=dimensions) + def column_to_dict(column: Column, dialect: Dialect) -> str: type_ = column.type if isinstance(type_, postgresql.DOMAIN): @@ -139,6 +154,7 @@ def column_to_dict(column: Column, dialect: Dialect) -> str: result["foreign_keys"] = foreign_keys return result + def dict_to_column( table_name, col_name, @@ -156,7 +172,7 @@ def dict_to_column( ForeignKey( 
fk, name=make_foreign_key_name(table_name, col_name), - ondelete='CASCADE', + ondelete="CASCADE", ) for fk in rep["foreign_keys"] if not ignore_fk(fk) @@ -171,21 +187,18 @@ def dict_to_column( nullable=rep.get("nullable", None), ) + def dict_to_unique(rep: dict) -> schema.UniqueConstraint: - return schema.UniqueConstraint( - *rep.get("columns", []), - name=rep.get("name", None) - ) + return schema.UniqueConstraint(*rep.get("columns", []), name=rep.get("name", None)) + def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: return { "name": constraint.name, - "columns": [ - str(col.name) - for col in constraint.columns - ] + "columns": [str(col.name) for col in constraint.columns], } + def table_to_dict(table: Table, dialect: Dialect) -> table_t: """ Converts a SQL Alchemy Table object into a @@ -203,6 +216,7 @@ def table_to_dict(table: Table, dialect: Dialect) -> table_t: ], } + def dict_to_table( name: str, meta: MetaData, @@ -212,15 +226,17 @@ def dict_to_table( return Table( name, meta, - *[ dict_to_column(name, colname, col, ignore_fk) + *[ + dict_to_column(name, colname, col, ignore_fk) for (colname, col) in table_dict.get("columns", {}).items() ], - *[ dict_to_unique(constraint) - for constraint in table_dict.get("unique", []) - ], + *[dict_to_unique(constraint) for constraint in table_dict.get("unique", [])], ) -def metadata_to_dict(meta: MetaData, schema_name: str | None, engine: Engine) -> dict[str, table_t]: + +def metadata_to_dict( + meta: MetaData, schema_name: str | None, engine: Engine +) -> dict[str, table_t]: """ Converts a SQL Alchemy MetaData object into a Python object ready for conversion to YAML. 
@@ -248,10 +264,7 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): return tables_dict[fk_bits[0]].get("ignore", False) -def dict_to_metadata( - obj: dict, - config_for_output: dict=None -) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict = None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. @@ -268,6 +281,6 @@ def dict_to_metadata( else: ignore_fk = lambda _: False meta = MetaData() - for (k, td) in tables_dict.items(): + for k, td in tables_dict.items(): dict_to_table(k, meta, td, ignore_fk) return meta diff --git a/datafaker/utils.py b/datafaker/utils.py index 33b8c846..883e0969 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -1,27 +1,20 @@ """Utility functions.""" import ast +import gzip +import importlib.util import json import logging import sys -import importlib.util from pathlib import Path from types import ModuleType from typing import Any, Final, Mapping, Optional, Union -import gzip +import sqlalchemy import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate from psycopg2.errors import UndefinedObject -import sqlalchemy -from sqlalchemy import ( - Connection, - Engine, - ForeignKey, - create_engine, - event, - select, -) +from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.exc import IntegrityError, ProgrammingError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -96,10 +89,15 @@ def open_compressed_file(file_name): def table_row_count(table: Table, conn: Connection) -> int: return conn.execute( - select(sqlalchemy.func.count()).select_from(sqlalchemy.table( - table.name, - *[sqlalchemy.column(col.name) for col in table.primary_key.columns.values()], - )) + select(sqlalchemy.func.count()).select_from( + sqlalchemy.table( + table.name, + *[ + sqlalchemy.column(col.name) + for col in 
table.primary_key.columns.values() + ], + ) + ) ).scalar_one() @@ -117,10 +115,7 @@ def download_table( rowcount = table_row_count(table, conn) count = 0 for row in conn.execute(stmt).mappings(): - result = { - str(col_name): value - for (col_name, value) in row.items() - } + result = {str(col_name): value for (col_name, value) in row.items()} yamlfile.write(yaml.dump([result]).encode()) count += 1 if count % MAKE_VOCAB_PROGRESS_REPORT_EVERY == 0: @@ -128,7 +123,7 @@ def download_table( "written row %d of %d, %.1f%%", count, rowcount, - 100*count/rowcount, + 100 * count / rowcount, ) @@ -213,6 +208,7 @@ class StdoutHandler(logging.Handler): A handler that writes to stdout. We aren't using StreamHandler because that confuses typer.testing.CliRunner """ + def flush(self): self.acquire() try: @@ -236,6 +232,7 @@ class StderrHandler(logging.Handler): A handler that writes to stderr. We aren't using StreamHandler because that confuses typer.testing.CliRunner """ + def flush(self): self.acquire() try: @@ -276,8 +273,8 @@ def conf_logger(verbose: bool) -> None: handlers=[stdout_handler, stderr_handler], force=True, ) - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('blib2to3.pgen2.driver').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) def get_flag(maybe_dict, key): @@ -370,7 +367,9 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] for fk in vocab_table.foreign_key_constraints: - logger.debug("Dropping constraint %s from table %s", fk.name, vocab_table_name) + logger.debug( + "Dropping constraint %s from table %s", fk.name, vocab_table_name + ) with Session(dst_engine) as session: session.begin() try: @@ -378,7 +377,11 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): session.commit() except 
IntegrityError: session.rollback() - logger.exception("Dropping table %s key constraint %s failed:", vocab_table_name, fk.name) + logger.exception( + "Dropping table %s key constraint %s failed:", + vocab_table_name, + fk.name, + ) except ProgrammingError as e: session.rollback() if type(e.orig) is UndefinedObject: @@ -392,7 +395,9 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] try: - for (column_name, column_dict) in meta_dict["tables"][vocab_table_name]["columns"].items(): + for column_name, column_dict in meta_dict["tables"][vocab_table_name][ + "columns" + ].items(): fk_targets = column_dict.get("foreign_keys", []) if fk_targets: fk = ForeignKeyConstraint( @@ -407,7 +412,9 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng session.execute(AddConstraint(fk)) session.commit() except IntegrityError: - logger.exception("Restoring table %s foreign keys failed:", vocab_table_name) + logger.exception( + "Restoring table %s foreign keys failed:", vocab_table_name + ) def stream_yaml(yaml_file_handle): @@ -437,7 +444,7 @@ def topological_sort(input_nodes, get_dependencies_fn): Topoligically sort input_nodes and find any cycles. Returns a pair (sorted, cycles). - + 'sorted' is a list of all the elements of input_nodes sorted so that dependencies returned by get_dependencies_fn come after nodes that depend on them. 
Cycles are @@ -478,23 +485,21 @@ def topological_sort(input_nodes, get_dependencies_fn): elif n in grey: # n is in a cycle cycle_start = grey.index(n) - cycles.append(grey[cycle_start:len(grey)]) + cycles.append(grey[cycle_start : len(grey)]) return (black, cycles) def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Table]: - table_names = set( - metadata.tables.keys() - ).difference( + table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) (sorted, cycles) = topological_sort( - table_names, - lambda tn: get_related_table_names(metadata.tables[tn]) + table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: logger.warning(f"Cycle detected between tables: {cycle}") - return [ metadata.tables[tn] for tn in sorted ] + return [metadata.tables[tn] for tn in sorted] + def generators_require_stats(config: Mapping) -> bool: """ @@ -527,14 +532,16 @@ def generators_require_stats(config: Mapping) -> bool: if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %d of %s: %s\n%s\n%s", - n + 1, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %d of %s: %s\n%s\n%s", + n + 1, + where, + e.msg, + arg, + " " * e.offset + "^" * max(1, e.end_offset - e.offset), + ) + ) for k, arg in call.get("kwargs", {}).items(): if type(arg) is str: try: @@ -546,14 +553,16 @@ def generators_require_stats(config: Mapping) -> bool: if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %s of %s: %s\n%s\n%s", - k, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %s of %s: %s\n%s\n%s", + k, + where, + e.msg, + arg, + " " * e.offset + "^" * max(1, e.end_offset - 
e.offset), + ) + ) for error in errors: logger.error(*error) return stats_required diff --git a/docs/source/_static/config_schema.html b/docs/source/_static/config_schema.html index ca0baa07..e78949f0 100644 --- a/docs/source/_static/config_schema.html +++ b/docs/source/_static/config_schema.html @@ -1 +1,2759 @@ - Datafaker Config

datafaker Config

Type: object

A datafaker configuration YAML file

No Additional Properties

Type: boolean

Run source-statistics queries using asyncpg.

Type: string

The name of a local Python module of row generators (excluding .py).

Type: string

The name of a local Python module of story generators (excluding .py).

Type: array

An array of source statistics queries.

Each item of this array must be:

Type: object
No Additional Properties

Type: string

A name for the query, which will be used in the stats file.

Type: string

A SQL query.

Type: string

A SmartNoise SQL query.

Type: number

The differential privacy epsilon value for the DP query.

Type: number

The differential privacy delta value for the DP query.

Type: object

See https://docs.smartnoise.org/sql/metadata.html#yaml-format.

All properties whose name matches the following regular expression must respect the following conditions

Property name regular expression: ^(?!(max_ids|row_privacy|sample_max_ids|censor_dims|clamp_counts|clamp_columns|use_dpsu)).*$
Type: object
No Additional Properties

Type: array of object

An array of story generators.

Each item of this array must be:

Type: object
No Additional Properties

Type: string

The full name of a story generator (e.g. mystorygenerators.short_story).

Type: array

Positional arguments to pass to the story generator.

Type: object

Keyword arguments to pass to the story generator.

Type: integer

The number of times to call the story generator per pass.

Type: integer

The maximum number of tries to respect a uniqueness constraint.

Type: object

Table configurations.

All properties whose name matches the following regular expression must respect the following conditions

Property name regular expression: .*
Type: object

A table configuration.

No Additional Properties

Type: boolean

Whether to completely ignore this table.

Type: boolean

Whether to export the table data.

Type: integer

The number of rows to generate per pass.

Type: array of object

An array of row generators to create column values.

Each item of this array must be:

Type: object

Type: string

The name of a (built-in or custom) function (e.g. max or myrowgenerators.my_gen).

Type: array

Positional arguments to pass to the function.

Type: object

Keyword arguments to pass to the function.

Type: array of string or string

One or more columns to assign the return value to.

Each item of this array must be:

\ No newline at end of file + + + + + + + + + + + + + + + + datafaker Config + + + +

datafaker Config

Type: object
+

A datafaker configuration YAML file

+
No Additional Properties + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+

Run source-statistics queries using asyncpg.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: string
+

The name of a local Python module of row generators (excluding .py).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: string
+

The name of a local Python module of story generators (excluding .py).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

Objects that need to be instantiated from the row and story generators modules.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array
+

An array of source statistics queries.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: object
+ No Additional Properties + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: string
+

A name for the query, which will be used in the stats file.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of string
+

Comments to be copied into the src-stats.yaml file describing the query results.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: string
+ + + + + + + +
+
+
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: string
+

A SQL query.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: string
+

A SmartNoise SQL query.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: number
+

The differential privacy epsilon value for the DP query.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: number
+

The differential privacy delta value for the DP query.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

See https://docs.smartnoise.org/sql/metadata.html#yaml-format.

+
+ + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: integer
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+ + + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+

+ +

+

All properties whose name matches the following regular expression must respect the following conditions

+ Property name regular expression: ^(?!(max_ids|row_privacy|sample_max_ids|censor_dims|clamp_counts|clamp_columns|use_dpsu)).*$ +
+ + Type: object
+ No Additional Properties + + + + + + +
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+

+ +

+
+ + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of object
+

An array of story generators.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: object
+ No Additional Properties + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: string
+

The full name of a story generator (e.g. mystorygenerators.short_story).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array
+

Positional arguments to pass to the story generator.

+
+ + + + + + No Additional Items +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

Keyword arguments to pass to the story generator.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: integer
+

The number of times to call the story generator per pass.

+
+ + + + + + +
+
+
+
+
+
+
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: integer
+

The maximum number of tries to respect a uniqueness constraint.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

Table configurations.

+
+ + + + + + +
+
+
+

+ +

+
+ +
+

+ +

+

All properties whose name matches the following regular expression must respect the following conditions

+ Property name regular expression: .* +
+ + Type: object
+

A table configuration.

+
No Additional Properties + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+

Whether to completely ignore this table.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+

Whether to export the table data.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: boolean
+

Whether the table is a Primary Private table (perhaps a table of patients).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: integer
+

The number of rows to generate per pass.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of object
+

An array of row generators to create column values.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: object
+ + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: string
+

The name of a (built-in or custom) function (e.g. max or myrowgenerators.my_gen).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array
+

Positional arguments to pass to the function.

+
+ + + + + + No Additional Items +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

Keyword arguments to pass to the function.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of string or string
+

One or more columns to assign the return value to.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: string
+ + + + + + + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of object
+

Function to generate a set of nullable columns that should not be null

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: object
+ + + + + + + +
+
+
+

+ +

+
+ +
+
+ + Type: string
+

The name of a (built-in or custom) function (e.g. column_presence.sampled).

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: object
+

Keyword arguments to pass to the function.

+
+ + + + + + +
+
+
+
+
+
+
+

+ +

+
+ +
+
+ + Type: array of string
+

Column names that might be returned.

+
+ + + + + + No Additional Items

Each item of this array must be:

+
+
+ + + Type: string
+ + + + + + + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ + + \ No newline at end of file diff --git a/docs/source/custom_generators.rst b/docs/source/custom_generators.rst index c29e340c..73d04b64 100644 --- a/docs/source/custom_generators.rst +++ b/docs/source/custom_generators.rst @@ -60,4 +60,4 @@ Again, you must define your own; ``datafaker`` provides no built-in story genera You can put your story generators in their own Python file, or you can re-use your row generators file if you like. -A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one). \ No newline at end of file +A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one). diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst index c588c0bc..8cf833a8 100644 --- a/docs/source/introduction.rst +++ b/docs/source/introduction.rst @@ -124,7 +124,7 @@ Some of these functions take arguments, that we can assign like this: Anyway, we now need to remake the generators (``create-generators``) and re-run them (``create-data``): .. code-block:: console - + $ datafaker create-generators --force $ datafaker create-data --num-passes 15 diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 48e52a03..43722aa6 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -110,7 +110,7 @@ This command will start an interactive command shell. Don't be intimidated, just columns data help next peek private select vocabulary counts empty ignore generate previous quit tables - (table: myfirsttable) + (table: myfirsttable) You can also get help for any of the commands listed; for example to see help for the ``vocabulary`` command type ``? vocabulary`` or ``help vocabulary``: @@ -134,7 +134,7 @@ Press the Tab key again to see these options: .. 
code-block:: console (table: actor) help p - peek previous private + peek previous private (table: actor) help p Now you can continue with r-i-tab to get ``private``, r-e-tab to get ``previous`` or e-tab to get ``peek``. This can be very useful; try pressing Tab twice on an empty line to see quickly all the possible commands, for example! @@ -372,7 +372,7 @@ To describe "null-partitioned grouped", let us make the generator much more comp | None | None | None | Pencil on tracing paper | | None | 18.5 | 24.3 | Lithograph from an illustrated book of poems and four lithographs | +----------+---------------+---------------+------------------------------------------------------------------------------------------------+ - (artwork.depth_cm,width_cm,height_cm,medium) + (artwork.depth_cm,width_cm,height_cm,medium) Here we can see that Moma understandably does not record depths for 2D artworks so we have many NULLs in that column. If we try to apply the standard normal or lognormal to data with many NULLs, it will ignore those rows with any NULLs. 
diff --git a/tests/test_create.py b/tests/test_create.py index 0fe1bf3e..333c01a2 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,9 +1,9 @@ """Tests for the create module.""" import itertools as itt -from collections import Counter import os -from pathlib import Path import random +from collections import Counter +from pathlib import Path from typing import Any, Generator, Tuple from unittest.mock import MagicMock, call, patch @@ -11,23 +11,26 @@ from sqlalchemy.schema import Table from datafaker.base import TableGenerator -from datafaker.create import ( - create_db_vocab, - populate, -) +from datafaker.create import create_db_vocab, populate from datafaker.remove import remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict from tests.utils import DatafakerTestCase, GeneratesDBTestCase + class TestCreate(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_create_vocab(self) -> None: """Test the create_db_vocab function.""" - with patch.dict(os.environ, {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, clear=True): + with patch.dict( + os.environ, + {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, + clear=True, + ): config = { "tables": { "player": { @@ -83,7 +86,7 @@ def test_make_table_generators(self) -> None: class TestPopulate(DatafakerTestCase): - """ Test create.populate. 
""" + """Test create.populate.""" def test_populate(self) -> None: """Test the populate function.""" diff --git a/tests/test_dump.py b/tests/test_dump.py index 4293f285..2d5ed268 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -1,27 +1,32 @@ """Tests for the base module.""" -from sqlalchemy.schema import MetaData -from tests.utils import RequiresDBTestCase from unittest.mock import MagicMock, call, patch +from sqlalchemy.schema import MetaData + from datafaker.dump import dump_db_tables +from tests.utils import RequiresDBTestCase + class DumpTests(RequiresDBTestCase): """Testing configure-tables.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: - """ Test dump-data. """ + """Test dump-data.""" TEST_OUTPUT_FILE = "test_output_file_object" metadata = MetaData() metadata.reflect(self.engine) dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) - make_csv_writer.assert_has_calls([ - call().writerow(["id", "given_name", "family_name"]), - call().writerow((1, 'Mark', 'Samson')), - call().writerow((2, 'Tim', 'Friedman')), - call().writerow((3, 'Pierre', 'Marchmont')), - ]) + make_csv_writer.assert_has_calls( + [ + call().writerow(["id", "given_name", "family_name"]), + call().writerow((1, "Mark", "Samson")), + call().writerow((2, "Tim", "Friedman")), + call().writerow((3, "Pierre", "Marchmont")), + ] + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 00f45478..418b4f96 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,14 +6,15 @@ from sqlalchemy import create_engine, inspect from typer.testing import CliRunner -from tests.utils import RequiresDBTestCase - from datafaker.main import app +from tests.utils import RequiresDBTestCase # pylint: disable=subprocess-run-check + class 
DBFunctionalTestCase(RequiresDBTestCase): """End-to-end tests that require a database.""" + dump_file_path = "src.dump" database_name = "src" schema_name = "public" @@ -33,7 +34,7 @@ class DBFunctionalTestCase(RequiresDBTestCase): generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), ) - #dump_file_path = Path("dst.dump") + # dump_file_path = Path("dst.dump") config_file_path = Path("example_config2.yaml") stats_file_path = Path("example_stats.yaml") @@ -430,7 +431,7 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args, expected_error: str=None, env={}): + def invoke(self, *args, expected_error: str = None, env={}): res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) @@ -513,12 +514,15 @@ def test_unique_constraint_fail(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - ("Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.full_row_story'\n" - "Generating data for story 'story_generators.long_story'\n" - "Generating data for story 'story_generators.long_story'\n") * 3, + ( + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.full_row_story'\n" + "Generating data for story 'story_generators.long_story'\n" + "Generating data for story 'story_generators.long_story'\n" + ) + * 3, completed_process.stdout, ) @@ -529,7 +533,7 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", f"--df-file={self.alt_datafaker_file_path}", "--num-passes=1", - expected_error = ( + expected_error=( "Failed to satisfy unique constraints 
for table unique_constraint_test" ), ) @@ -538,7 +542,7 @@ def test_unique_constraint_fail(self) -> None: def test_create_schema(self) -> None: """Check that we create a destination schema if it doesn't exist.""" - env = { "dst_schema": "doesntexistyetschema" } + env = {"dst_schema": "doesntexistyetschema"} engine = create_engine(self.env["dst_dsn"]) inspector = inspect(engine) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 386aeaa9..94872845 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1,58 +1,66 @@ """ Tests for the base module. """ import copy -from dataclasses import dataclass import random import re +from dataclasses import dataclass +from unittest.mock import MagicMock, Mock, patch + from sqlalchemy import insert, select +from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.interactive import ( DbCmd, - TableCmd, GeneratorCmd, MissingnessCmd, + TableCmd, update_config_generators, ) -from datafaker.generators import NullPartitionedNormalGeneratorFactory - -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase -from unittest.mock import MagicMock, Mock, patch +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase class TestDbCmdMixin(DbCmd): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.reset() + def reset(self): self.messages: list[tuple[str, list, dict[str, any]]] = [] self.headings: list[str] = [] self.rows: list[list[str]] = [] self.column_items: list[str] = [] self.columns: dict[str, list[str]] = {} + def print(self, text: str, *args, **kwargs): self.messages.append((text, args, kwargs)) + def print_table(self, headings: list[str], rows: list[list[str]]): self.headings = headings self.rows = rows + def print_table_by_columns(self, columns: dict[str, list[str]]): self.columns = columns + def columnize(self, items: list[str]): self.column_items.append(items) + def ask_save(self) -> str: return "yes" class 
TestTableCmd(TableCmd, TestDbCmdMixin): - """ TableCmd but mocked """ + """TableCmd but mocked""" class ConfigureTablesTests(RequiresDBTestCase): """Testing configure-tables.""" + def _get_cmd(self, config) -> TestTableCmd: return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) class ConfigureTablesSrcTests(ConfigureTablesTests): """Testing configure-tables with src.dump.""" + dump_file_path = "src.dump" database_name = "src" schema_name = "public" @@ -70,11 +78,15 @@ def test_table_name_prompts(self) -> None: for t in reversed(table_names): self.assertIn(t, tc.prompt) tc.do_previous("") - self.assertListEqual(tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})]) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] + ) tc.reset() bad_table_name = "notarealtable" tc.do_next(bad_table_name) - self.assertListEqual(tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})]) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] + ) tc.reset() good_table_name = table_names[2] tc.do_next(good_table_name) @@ -107,9 +119,13 @@ def test_null_configuration(self) -> None: tc.do_private("") tc.do_quit("") tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) def test_null_table_configuration(self) -> None: """A table still works if its configuration is None.""" @@ -123,9 +139,13 @@ def test_null_table_configuration(self) -> None: tc.do_private("") tc.do_quit("") tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + 
tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) def test_configure_tables(self) -> None: """Test that we can change columns to ignore, vocab or generate.""" @@ -142,7 +162,7 @@ def test_configure_tables(self) -> None: }, "empty_vocabulary": { "private": True, - } + }, }, } with self._get_cmd(config) as tc: @@ -159,9 +179,13 @@ def test_configure_tables(self) -> None: tc.do_empty("") tc.do_quit("") tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertFalse(tables["unique_constraint_test"].get("primary_private", False)) + self.assertFalse( + tables["unique_constraint_test"].get("primary_private", False) + ) self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) self.assertTrue(tables["no_pk_test"].get("ignore", False)) @@ -189,10 +213,7 @@ def test_print_data(self) -> None: person_table = self.metadata.tables["person"] with self.engine.connect() as conn: person_rows = conn.execute(select(person_table)).mappings().fetchall() - person_data = { - row["person_id"]: row - for row in person_rows - } + person_data = {row["person_id"]: row for row in person_rows} name_set = {row["name"] for row in person_rows} person_headings = ["person_id", "name", "research_opt_out", "stored_from"] with self._get_cmd({}) as tc: @@ -224,11 +245,15 @@ def test_print_data(self) -> None: tc.reset() tc.do_data(f"{to_get_count} name 13") self.assertEqual(len(tc.column_items), 1) - 
self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set))) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) + ) tc.reset() tc.do_data(f"{to_get_count} name 16") self.assertEqual(len(tc.column_items), 1) - self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set))) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) + ) def test_list_tables(self): """Test that we can list the tables""" @@ -252,7 +277,7 @@ def test_list_tables(self): person_listed = False unique_constraint_test_listed = False no_pk_test_listed = False - for (text, args, kwargs) in tc.messages: + for text, args, kwargs in tc.messages: if args[2] == "person": self.assertFalse(person_listed) person_listed = True @@ -277,7 +302,8 @@ def test_list_tables(self): class ConfigureTablesInstrumentsTests(ConfigureTablesTests): - """ Testing configure-tables with the instrument.sql database. """ + """Testing configure-tables with the instrument.sql database.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -300,10 +326,28 @@ def test_sanity_checks_both(self): tc.reset() tc.do_quit("") self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) - self.assertEqual(tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[4], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) + self.assertEqual( + tc.messages[3], 
(TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[4], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) def test_sanity_checks_warnings_only(self): config = { @@ -324,9 +368,25 @@ def test_sanity_checks_warnings_only(self): tc.do_vocabulary("") tc.reset() tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("manufacturer", "ignore", "vocabulary"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("manufacturer", "ignore", "vocabulary"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) def test_sanity_checks_errors_only(self): config = { @@ -347,16 +407,34 @@ def test_sanity_checks_errors_only(self): tc.do_empty("") tc.reset() tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("signature_model", "generate", "empty"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("signature_model", "generate", "empty"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): - """ GeneratorCmd but mocked """ + """GeneratorCmd but mocked""" + def get_proposals(self) -> 
dict[str, tuple[int, str, str, list[str]]]: """ - Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples])""" + Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples]) + """ return { kw["name"]: (kw["index"], kw["fit"], kw["sample"].split("; ")) for (s, _, kw) in self.messages @@ -365,7 +443,8 @@ def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: class ConfigureGeneratorsTests(RequiresDBTestCase): - """ Testing configure-generators. """ + """Testing configure-generators.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -374,7 +453,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_null_configuration(self): - """ Test that the tables having null configuration does not break. """ + """Test that the tables having null configuration does not break.""" config = { "tables": None, } @@ -388,7 +467,7 @@ def test_null_configuration(self): self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) def test_null_table_configuration(self): - """ Test that a table having null configuration does not break. 
""" + """Test that a table having null configuration does not break.""" config = { "tables": { "model": None, @@ -415,10 +494,14 @@ def test_prompts(self) -> None: else: self.assertNotIn("[pk]", gc.prompt) gc.do_next("") - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})] + ) gc.reset() for table_name, table_meta in reversed(list(self.metadata.tables.items())): - for column_name, column_meta in reversed(list(table_meta.columns.items())): + for column_name, column_meta in reversed( + list(table_meta.columns.items()) + ): self.assertIn(table_name, gc.prompt) self.assertIn(column_name, gc.prompt) if column_meta.primary_key: @@ -426,19 +509,20 @@ def test_prompts(self) -> None: else: self.assertNotIn("[pk]", gc.prompt) gc.do_previous("") - self.assertListEqual(gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})] + ) gc.reset() bad_table_name = "notarealtable" gc.do_next(bad_table_name) - self.assertListEqual(gc.messages, [( - GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, - (bad_table_name,), - {} - )]) + self.assertListEqual( + gc.messages, + [(GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, (bad_table_name,), {})], + ) gc.reset() def test_set_generator_mimesis(self): - """ Test that we can set one generator to a mimesis generator. """ + """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: TABLE = "model" COLUMN = "name" @@ -455,7 +539,7 @@ def test_set_generator_mimesis(self): ) def test_set_generator_distribution(self): - """ Test that we can set one generator to gaussian. 
""" + """Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -470,12 +554,17 @@ def test_set_generator_distribution(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', + "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") self.assertEqual( gc.config["src-stats"][0]["query"], @@ -483,7 +572,7 @@ def test_set_generator_distribution(self): ) def test_set_generator_distribution_directly(self): - """ Test that we can set one generator to gaussian without going through propose. 
""" + """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -494,7 +583,9 @@ def test_set_generator_distribution_directly(self): self.assertListEqual(gc.messages, []) gc.do_quit("") self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") self.assertEqual( gc.config["src-stats"][0]["query"], @@ -502,7 +593,7 @@ def test_set_generator_distribution_directly(self): ) def test_set_generator_choice(self): - """ Test that we can set one generator to uniform choice. """ + """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -517,19 +608,26 @@ def test_set_generator_choice(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}") + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual( + gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}" + ) self.assertEqual( gc.config["src-stats"][0]["query"], f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", ) def 
test_weighted_choice_generator_generates_choices(self): - """ Test that propose and compare show weighted_choice's values. """ + """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "position" @@ -546,7 +644,7 @@ def test_weighted_choice_generator_generates_choices(self): self.assertSubset(set(gc.columns[col_heading]), VALUES) def test_merge_columns(self): - """ Test that we can merge columns and set a multivariate generator """ + """Test that we can merge columns and set a multivariate generator""" TABLE = "string" COLUMN_1 = "frequency" COLUMN_2 = "position" @@ -586,7 +684,7 @@ def test_merge_columns(self): self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) def test_unmerge_columns(self): - """ Test that we can unmerge columns and generators are removed """ + """Test that we can unmerge columns and generators are removed""" TABLE = "string" COLUMN_1 = "frequency" COLUMN_2 = "position" @@ -597,7 +695,7 @@ def test_unmerge_columns(self): TABLE: { "row_generators": [ {"name": "gen1", "columns_assigned": [COLUMN_1, COLUMN_2]}, - { "name": REMAINING_GEN, "columns_assigned": [COLUMN_3] }, + {"name": REMAINING_GEN, "columns_assigned": [COLUMN_3]}, ] } } @@ -625,24 +723,28 @@ def test_unmerge_columns(self): self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) def test_old_generators_remain(self): - """ Test that we can set one generator and keep an old one. 
""" + """Test that we can set one generator and keep an old one.""" config = { "tables": { "string": { - "row_generators": [{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }, - }] + "row_generators": [ + { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + } + ] } }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] + "src-stats": [ + { + "name": "auto__string", + "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + } + ], } with self._get_cmd(config) as gc: TABLE = "model" @@ -663,18 +765,23 @@ def test_old_generators_remain(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], "dist_gen.normal") self.assertListEqual(row_gen["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", ) - + def 
test_aggregate_queries_merge(self): """ Test that we can set a generator that requires select aggregate clauses @@ -683,20 +790,24 @@ def test_aggregate_queries_merge(self): config = { "tables": { "string": { - "row_generators": [{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }, - }] + "row_generators": [ + { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + } + ] } }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] + "src-stats": [ + { + "name": "auto__string", + "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + } + ], } with self._get_cmd(copy.deepcopy(config)) as gc: COLUMN = "position" @@ -706,7 +817,9 @@ def test_aggregate_queries_merge(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - row_gens: list[dict[str,any]] = gc.config["tables"]["string"]["row_generators"] + row_gens: list[dict[str, any]] = gc.config["tables"]["string"][ + "row_generators" + ] self.assertEqual(len(row_gens), 2) if row_gens[0]["name"] == GENERATOR: row_gen0 = row_gens[0] @@ -717,28 +830,41 @@ def test_aggregate_queries_merge(self): self.assertEqual(row_gen0["name"], GENERATOR) self.assertEqual(row_gen1["name"], "dist_gen.normal") self.assertListEqual(row_gen0["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen0["kwargs"], { - "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', - }) + self.assertDictEqual( + 
row_gen0["kwargs"], + { + "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', + "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', + }, + ) self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen1["kwargs"], { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }) + self.assertDictEqual( + row_gen1["kwargs"], + { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") - select_match = re.match(r'SELECT (.*) FROM string', gc.config["src-stats"][0]["query"]) - self.assertIsNotNone(select_match, "src_stats[0].query is not an aggregate select") - self.assertSetEqual(set(select_match.group(1).split(", ")), { - "AVG(frequency) AS mean__frequency", - "STDDEV(frequency) AS stddev__frequency", - f"AVG({COLUMN}) AS mean__{COLUMN}", - f"STDDEV({COLUMN}) AS stddev__{COLUMN}", - }) + select_match = re.match( + r"SELECT (.*) FROM string", gc.config["src-stats"][0]["query"] + ) + self.assertIsNotNone( + select_match, "src_stats[0].query is not an aggregate select" + ) + self.assertSetEqual( + set(select_match.group(1).split(", ")), + { + "AVG(frequency) AS mean__frequency", + "STDDEV(frequency) AS stddev__frequency", + f"AVG({COLUMN}) AS mean__{COLUMN}", + f"STDDEV({COLUMN}) AS stddev__{COLUMN}", + }, + ) def test_next_completion(self): - """ Test tab completion for the next command. 
""" + """Test tab completion for the next command.""" with self._get_cmd({}) as gc: self.assertSetEqual( set(gc.complete_next("m", "next m", 5, 6)), @@ -756,7 +882,9 @@ def test_next_completion(self): set(gc.complete_next("string.p", "next string.p", 5, 12)), {"string.position"}, ) - self.assertListEqual(gc.complete_next("string.q", "next string.q", 5, 12), []) + self.assertListEqual( + gc.complete_next("string.q", "next string.q", 5, 12), [] + ) self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) def test_compare_reports_privacy(self): @@ -799,10 +927,12 @@ def test_existing_configuration_remains(self): "primary_private": True, } }, - "src-stats": [{ - "name": "kraken", - "query": 'SELECT MAX(frequency) AS max_frequency FROM string', - }] + "src-stats": [ + { + "name": "kraken", + "query": "SELECT MAX(frequency) AS max_frequency FROM string", + } + ], } with self._get_cmd(config) as gc: COLUMN = "position" @@ -812,15 +942,12 @@ def test_existing_configuration_remains(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - src_stats = { - stat["name"]: stat["query"] - for stat in gc.config["src-stats"] - } + src_stats = {stat["name"]: stat["query"] for stat in gc.config["src-stats"]} self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) self.assertTrue(gc.config["tables"]["string"]["primary_private"]) def test_empty_tables_are_not_configured(self): - """ Test that tables marked as empty are not configured. """ + """Test that tables marked as empty are not configured.""" config = { "tables": { "string": { @@ -830,13 +957,14 @@ def test_empty_tables_are_not_configured(self): } with self._get_cmd(copy.deepcopy(config)) as gc: gc.do_tables("") - table_names = { m[1][0] for m in gc.messages } + table_names = {m[1][0] for m in gc.messages} self.assertIn("model", table_names) self.assertNotIn("string", table_names) class GeneratorsOutputTests(GeneratesDBTestCase): - """ Testing choice generation. 
""" + """Testing choice generation.""" + dump_file_path = "choice.sql" database_name = "numbers" schema_name = "public" @@ -845,7 +973,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_sampled_choice(self): - """ Test that suppression works for choice and zipf_choice. """ + """Test that suppression works for choice and zipf_choice.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -869,7 +997,9 @@ def test_create_with_sampled_choice(self): self.assertIn("dist_gen.zipf_choice [sampled]", proposals) self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) - gc.do_set(str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0])) + gc.do_set( + str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]) + ) gc.do_next("number_table.three") gc.reset() gc.do_propose("") @@ -899,7 +1029,7 @@ def test_create_with_sampled_choice(self): self.assertSetEqual(threes, {1, 2, 3, 4, 5}) def test_create_with_choice(self): - """ Smoke test normal choice works. """ + """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -927,7 +1057,7 @@ def test_create_with_choice(self): self.assertSetEqual(twos, {1, 2, 3, 4, 5}) def test_create_with_weighted_choice(self): - """ Smoke test weighted choice. 
""" + """Smoke test weighted choice.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -936,12 +1066,16 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] self.assertSubset(set(prop[2]), {"1", "4"}) gc.reset() gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" + col_heading = ( + f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" + ) self.assertIn(col_heading, set(gc.columns.keys())) self.assertSubset(set(gc.columns[col_heading]), {1, 4}) gc.do_set(str(prop[0])) @@ -951,7 +1085,9 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice"] self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.reset() @@ -966,7 +1102,9 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertNotIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertNotIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice [sampled]"] self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.do_compare(str(prop[0])) @@ -993,10 +1131,12 
@@ def test_create_with_weighted_choice(self): class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): - """ MissingnessCmd but mocked """ + """MissingnessCmd but mocked""" + class ConfigureMissingnessTests(RequiresDBTestCase): - """ Testing configure-missing. """ + """Testing configure-missing.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1005,34 +1145,50 @@ def _get_cmd(self, config) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) def test_set_missingness_to_sampled(self): - """ Test that we can set one table to sampled missingness. """ + """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: TABLE = "signature_model" mc.do_next(TABLE) mc.do_counts("") - self.assertListEqual(mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})]) - self.assertListEqual(mc.rows, [['player_id', 3], ['based_on', 2]]) + self.assertListEqual( + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})] + ) + self.assertListEqual(mc.rows, [["player_id", 3], ["based_on", 2]]) mc.do_sampled("") mc.do_quit("") self.assertDictEqual( mc.config, - { "tables": {TABLE: {"missingness_generators": [{ - "columns": ["player_id", "based_on"], - "kwargs": {"patterns": 'SRC_STATS["missing_auto__signature_model__0"]'}, - "name": "column_presence.sampled", - }]}}, - "src-stats": [{ - "name": "missing_auto__signature_model__0", - "query": ("SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" + { + "tables": { + TABLE: { + "missingness_generators": [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]' + }, + "name": "column_presence.sampled", + } + ] + } + }, + "src-stats": [ + { + "name": "missing_auto__signature_model__0", + "query": ( + "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL 
AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null") - }] - } + " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + ), + } + ], + }, ) class ConfigureMissingnessTests(GeneratesDBTestCase): - """ Testing configure-missing with generation. """ + """Testing configure-missing with generation.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1041,7 +1197,7 @@ def _get_cmd(self, config) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_missingness(self): - """ Test that we can sample real missingness and reproduce it. """ + """Test that we can sample real missingness and reproduce it.""" random.seed(45) # Configure the missingness table_name = "signature_model" @@ -1065,7 +1221,8 @@ def test_create_with_missingness(self): class GeneratorTests(GeneratesDBTestCase): - """ Testing configure-generators with generation. """ + """Testing configure-generators with generation.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1074,7 +1231,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_set_null(self): - """ Test that we can sample real missingness and reproduce it. 
""" + """Test that we can sample real missingness and reproduce it.""" with self._get_cmd({}) as gc: gc.do_next("string.position") gc.do_set("dist_gen.constant") @@ -1091,7 +1248,9 @@ def test_set_null(self): gc.do_next("signature_model.based_on") gc.do_set("dist_gen.constant") # we have got to the end of the columns, but shouldn't have any errors - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})] + ) gc.reset() gc.do_quit("") config = gc.config @@ -1116,7 +1275,7 @@ def test_set_null(self): self.assertEqual(count, 3) def test_dist_gen_sampled_produces_ordered_src_stats(self): - """ Tests that choosing a sampled choice generator produces ordered src stats """ + """Tests that choosing a sampled choice generator produces ordered src stats""" with self._get_cmd({}) as gc: gc.do_next("signature_model.player_id") gc.do_set("dist_gen.zipf_choice [sampled]") @@ -1127,26 +1286,24 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self): self.set_configuration(config) src_stats = self.get_src_stats(config) player_ids = [ - s["value"] - for s in src_stats["auto__signature_model__player_id"]["results"] + s["value"] for s in src_stats["auto__signature_model__player_id"]["results"] ] self.assertListEqual(player_ids, [2, 3, 1]) based_ons = [ - s["value"] - for s in src_stats["auto__signature_model__based_on"]["results"] + s["value"] for s in src_stats["auto__signature_model__based_on"]["results"] ] self.assertListEqual(based_ons, [1, 3, 2]) def assertAreTruncatedTo(self, xs, length): - maxlen = 0 - for x in xs: - newlen = len(x.strip("'\"")) - self.assertLessEqual(newlen, length) - maxlen = max(maxlen, newlen) - self.assertEqual(maxlen, length) + maxlen = 0 + for x in xs: + newlen = len(x.strip("'\"")) + self.assertLessEqual(newlen, length) + maxlen = max(maxlen, newlen) + self.assertEqual(maxlen, length) def test_varchar_ns_are_truncated(self): - """ Tests 
that mimesis generators for VARCHAR(N) truncate to N characters """ + """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" GENERATOR = "generic.text.quote" TABLE = "signature_model" COLUMN = "name" @@ -1176,9 +1333,9 @@ def test_varchar_ns_are_truncated(self): @dataclass class Stat: - n: int=0 - x: float=0 - x2: float=0 + n: int = 0 + x: float = 0 + x2: float = 0 def add(self, x: float) -> None: self.n += 1 @@ -1193,14 +1350,14 @@ def x_mean(self) -> float: def x_var(self) -> float: x = self.x - return (self.x2 - x*x/self.n)/(self.n - 1) + return (self.x2 - x * x / self.n) / (self.n - 1) @dataclass class Correlation(Stat): - y: float=0 - y2: float=0 - xy: float=0 + y: float = 0 + y2: float = 0 + xy: float = 0 def add(self, x: float, y: float) -> None: self.n += 1 @@ -1215,14 +1372,15 @@ def y_mean(self) -> float: def y_var(self) -> float: y = self.y - return (self.y2 - y*y/self.n)/(self.n - 1) + return (self.y2 - y * y / self.n) / (self.n - 1) def covar(self) -> float: - return (self.xy - self.x*self.y/self.n)/(self.n - 1) + return (self.xy - self.x * self.y / self.n) / (self.n - 1) class NullPartitionedTests(GeneratesDBTestCase): - """ Testing null-partitioned grouped multivariate generation. """ + """Testing null-partitioned grouped multivariate generation.""" + dump_file_path = "eav.sql" database_name = "eav" schema_name = "public" @@ -1236,7 +1394,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_null_partitioned_grouped_multivariate(self): - """ Test EAV for all columns. 
""" + """Test EAV for all columns.""" table_name = "measurement" generate_count = 800 with self._get_cmd({}) as gc: @@ -1287,9 +1445,9 @@ def test_create_with_null_partitioned_grouped_multivariate(self): # yes or no self.assertIsNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) + self.assertIn(row.third_value, {"yes", "no"}) one_count += 1 - if row.third_value == 'yes': + if row.third_value == "yes": one_yes_count += 1 elif row.type == 2: # positive correlation around 1.4, 1.8 @@ -1310,43 +1468,57 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertIsNone(row.third_value) four.add(row.first_value, row.second_value) elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) + self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - if row.third_value == 'fish': + if row.third_value == "fish": # mean 8.1 and sd 0.755 fish.add(row.first_value) else: # mean 11.2 and sd 1.114 fowl.add(row.first_value) # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 20, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_yes_count / one_count, 0.4, delta=generate_count * 0.4 + ) # type 2 - self.assertAlmostEqual(two.count(), generate_count * 3 / 20, delta=generate_count * 0.5) + self.assertAlmostEqual( + two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) self.assertAlmostEqual(two.x_var(), 0.21, delta=0.4) self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) self.assertAlmostEqual(two.y_var(), 0.07, delta=0.1) self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) # type 3 - self.assertAlmostEqual(three.count(), generate_count * 3 / 20, 
delta=generate_count * 0.2) + self.assertAlmostEqual( + three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(two.covar(), -0.5, delta=0.5) # type 4 - self.assertAlmostEqual(four.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): - """ Test EAV for all columns with sampled and suppressed generation. 
""" + """Test EAV for all columns with sampled and suppressed generation.""" table_name = "measurement" table2_name = "observation" generate_count = 800 @@ -1360,8 +1532,13 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): proposals = gc.get_proposals() self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn("null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + proposals, + ) + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) self.assertIn(dist_to_choose, proposals) prop = proposals[dist_to_choose] gc.reset() @@ -1377,7 +1554,9 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): gc.reset() gc.do_propose("") proposals = gc.get_proposals() - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) prop = proposals[dist_to_choose] gc.do_set(str(prop[0])) gc.do_quit("") @@ -1409,15 +1588,15 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): # yes or no self.assertIsNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) - if row.third_value == 'yes': + self.assertIn(row.third_value, {"yes", "no"}) + if row.third_value == "yes": one_yes_count += 1 one_count += 1 elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) + self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - if row.third_value == 'fish': + if row.third_value == "fish": # mean 8.1 and sd 
0.755 fish.add(row.first_value) else: @@ -1427,15 +1606,23 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertEqual(len(types), 4) self.assertSubset({1, 5}, types) # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 11, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_yes_count / one_count, 0.4, delta=generate_count * 0.4 + ) # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2) + self.assertAlmostEqual( + fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2) + self.assertAlmostEqual( + fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) stmt = select(self.metadata.tables[table2_name]) @@ -1446,38 +1633,52 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertEqual(row.type, 1) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'ham', 'eggs'}) + self.assertIn(row.third_value, {"ham", "eggs"}) firsts.add(row.first_value) self.assertEqual(firsts.count(), 800) - self.assertAlmostEqual(firsts.x_mean(), 1.3, delta = generate_count * 0.3) + self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) class NonInteractiveTests(RequiresDBTestCase): """ Test the --spec SPEC_FILE option of configure-generators """ + dump_file_path = "eav.sql" database_name = "eav" schema_name = "public" 
@patch("datafaker.interactive.Path") - @patch("datafaker.interactive.csv.reader", return_value=iter([ - ["observation", "type", "dist_gen.weighted_choice"], - ["observation", "first_value", "dist_gen.weighted_choice"], - ["observation", "third_value", "dist_gen.weighted_choice"], - ])) - def test_non_interactive_configure_generators(self, mock_csv_reader: MagicMock, mock_path: MagicMock): + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + ["observation", "type", "dist_gen.weighted_choice"], + ["observation", "first_value", "dist_gen.weighted_choice"], + ["observation", "third_value", "dist_gen.weighted_choice"], + ] + ), + ) + def test_non_interactive_configure_generators( + self, mock_csv_reader: MagicMock, mock_path: MagicMock + ): """ test that we can set generators from a CSV file """ config = {} spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators(self.dsn, self.schema_name, self.metadata, config, spec_csv) + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) row_gens = { f"{table}{sorted(rg['columns_assigned'])}": rg["name"] for table, tables in config.get("tables", {}).items() for rg in tables.get("row_generators", []) } self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['first_value']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['third_value']"], "dist_gen.weighted_choice") + self.assertEqual( + row_gens["observation['first_value']"], "dist_gen.weighted_choice" + ) + self.assertEqual( + row_gens["observation['third_value']"], "dist_gen.weighted_choice" + ) diff --git a/tests/test_main.py b/tests/test_main.py index c3652a2d..a37eaf4f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,7 +21,13 @@ class TestCLI(DatafakerTestCase): @patch("datafaker.main.dict_to_metadata") @patch("datafaker.main.load_metadata_config") @patch("datafaker.main.create_db_vocab") - def 
test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_meta: MagicMock, mock_config: MagicMock) -> None: + def test_create_vocab( + self, + mock_create: MagicMock, + mock_mdict: MagicMock, + mock_meta: MagicMock, + mock_config: MagicMock, + ) -> None: """Test the create-vocab sub-command.""" result = runner.invoke( app, @@ -31,7 +37,9 @@ def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_ catch_exceptions=False, ) - mock_create.assert_called_once_with(mock_meta.return_value, mock_mdict.return_value, mock_config.return_value) + mock_create.assert_called_once_with( + mock_meta.return_value, mock_mdict.return_value, mock_config.return_value + ) self.assertSuccess(result) @patch("datafaker.main.read_config_file") @@ -159,10 +167,13 @@ def test_create_generators_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "create-generators", - force_option, - ]) + result: Result = runner.invoke( + app, + [ + "create-generators", + force_option, + ], + ) mock_make.assert_called_once_with( mock_load_meta.return_value, @@ -335,11 +346,14 @@ def test_make_tables_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "make-tables", - force_option, - "--orm-file=tests/examples/example_orm.yaml", - ]) + result: Result = runner.invoke( + app, + [ + "make-tables", + force_option, + "--orm-file=tests/examples/example_orm.yaml", + ], + ) mock_make_tables.assert_called_once_with( mock_get_settings.return_value.src_dsn, @@ -359,7 +373,11 @@ def test_make_tables_with_force_enabled( @patch("datafaker.main.get_settings") @patch("datafaker.main.load_metadata", side_effect=["ms"]) def test_make_stats( - self, _lm: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + _lm: MagicMock, + mock_get_settings: MagicMock, 
+ mock_make: MagicMock, + mock_path: MagicMock, ) -> None: """Test the make-stats sub-command.""" example_conf_path = "tests/examples/example_config.yaml" @@ -379,7 +397,9 @@ def test_make_stats( self.assertSuccess(result) with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - mock_make.assert_called_once_with(get_test_settings().src_dsn, config, "ms", None) + mock_make.assert_called_once_with( + get_test_settings().src_dsn, config, "ms", None + ) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" ) @@ -434,7 +454,11 @@ def test_make_stats_errors_if_no_src_dsn(self, mock_logger: MagicMock) -> None: @patch("datafaker.main.get_settings") @patch("datafaker.main.load_metadata") def test_make_stats_with_force_enabled( - self, mock_meta: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + mock_meta: MagicMock, + mock_get_settings: MagicMock, + mock_make: MagicMock, + mock_path: MagicMock, ) -> None: """Tests that the make-stats command overwrite files when instructed.""" test_config_file: str = "tests/examples/example_config.yaml" @@ -461,7 +485,10 @@ def test_make_stats_with_force_enabled( ) mock_make.assert_called_once_with( - test_settings.src_dsn, config_file_content, mock_meta.return_value, None + test_settings.src_dsn, + config_file_content, + mock_meta.return_value, + None, ) mock_path.return_value.write_text.assert_called_once_with( "some_stat: 0\n", encoding="utf-8" @@ -507,7 +534,9 @@ def test_remove_data( catch_exceptions=False, ) self.assertEqual(0, result.exit_code) - mock_remove.assert_called_once_with(mock_meta.return_value, mock_config.return_value) + mock_remove.assert_called_once_with( + mock_meta.return_value, mock_config.return_value + ) @patch("datafaker.main.read_config_file") @patch("datafaker.main.remove_db_vocab") @@ -528,12 +557,18 @@ def test_remove_vocab( ) self.assertEqual(0, result.exit_code) 
mock_read_config.assert_called_once_with("config.yaml") - mock_remove.assert_called_once_with(mock_d2m.return_value, mock_load_metadata.return_value, mock_read_config.return_value) + mock_remove.assert_called_once_with( + mock_d2m.return_value, + mock_load_metadata.return_value, + mock_read_config.return_value, + ) @patch("datafaker.main.remove_db_tables") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.main.read_config_file") - def test_remove_tables(self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock) -> None: + def test_remove_tables( + self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock + ) -> None: """Test the remove-tables command.""" result = runner.invoke( app, diff --git a/tests/test_make.py b/tests/test_make.py index 6100aa77..f43588ac 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -9,37 +9,38 @@ from sqlalchemy.dialects.mysql.types import INTEGER from sqlalchemy.dialects.postgresql import UUID -from datafaker.make import ( - _get_provider_for_column, - make_src_stats, -) -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase +from datafaker.make import _get_provider_for_column, make_src_stats +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase class TestMakeGenerators(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_make_table_generators(self) -> None: - """ Check that we can make a generators file. 
""" + """Check that we can make a generators file.""" config = { "tables": { "player": { - "row_generators": [{ - "name": "dist_gen.constant", - "kwargs": { - "value": '"Cave"', + "row_generators": [ + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Cave"', + }, + "columns_assigned": "given_name", }, - "columns_assigned": "given_name", - }, { - "name": "dist_gen.constant", - "kwargs": { - "value": '"Johnson"', + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Johnson"', + }, + "columns_assigned": "family_name", }, - "columns_assigned": "family_name", - }], + ], }, }, } @@ -96,7 +97,7 @@ def test_get_provider_for_column(self) -> None: ) self.assertEqual( generator_arguments, - { "length": "100" }, + {"length": "100"}, ) # UUID @@ -149,12 +150,15 @@ def check_make_stats_output(self, src_stats: dict) -> None: count_names = src_stats["count_names"]["results"] count_names.sort(key=lambda c: c["name"]) - self.assertListEqual(count_names, [ - {"num": 1, "name": "Miranda Rando-Generata"}, - {"num": 997, "name": "Randy Random"}, - {"num": 1, "name": "Testfried Testermann"}, - {"num": 1, "name": "Veronica Fyre"}, - ]) + self.assertListEqual( + count_names, + [ + {"num": 1, "name": "Miranda Rando-Generata"}, + {"num": 997, "name": "Randy Random"}, + {"num": 1, "name": "Testfried Testermann"}, + {"num": 1, "name": "Veronica Fyre"}, + ], + ) avg_person_id = src_stats["avg_person_id"]["results"] self.assertEqual(len(avg_person_id), 1) diff --git a/tests/test_providers.py b/tests/test_providers.py index 9cc03c5e..aedb693a 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.declarative import declarative_base from datafaker import providers -from tests.utils import RequiresDBTestCase, DatafakerTestCase +from tests.utils import DatafakerTestCase, RequiresDBTestCase # pylint: disable=invalid-name Base = declarative_base() @@ -37,6 +37,7 @@ def test_bytes(self) -> None: class 
ColumnValueProviderTestCase(RequiresDBTestCase): """Tests for the ColumnValueProvider class.""" + dump_file_path = "providers.dump" def setUp(self) -> None: diff --git a/tests/test_remove.py b/tests/test_remove.py index 660d6cb8..bfbb787d 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -1,25 +1,25 @@ """Tests for the remove module.""" from unittest.mock import MagicMock, patch +from sqlalchemy import func, inspect, select + from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict from datafaker.settings import Settings -from sqlalchemy import func, inspect, select from tests.utils import RequiresDBTestCase class RemoveThingsTestCase(RequiresDBTestCase): - """ Tests for ``remove-`` commands. """ + """Tests for ``remove-`` commands.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def count_rows(self, connection, table_name: str) -> int | None: - return connection.execute(select( - func.count() - ).select_from( - self.metadata.tables[table_name] - )).scalar() + return connection.execute( + select(func.count()).select_from(self.metadata.tables[table_name]) + ).scalar() @patch("datafaker.remove.get_settings") def test_remove_data(self, mock_get_settings: MagicMock): @@ -28,12 +28,15 @@ def test_remove_data(self, mock_get_settings: MagicMock): dst_dsn=self.dsn, _env_file=None, ) - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) with self.engine.connect() as conn: self.assertGreater(self.count_rows(conn, "manufacturer"), 0) self.assertGreater(self.count_rows(conn, "model"), 0) @@ -43,19 +46,22 @@ def test_remove_data(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") 
def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-data raises if dst DSN is missing. """ + """Test that remove-data raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @@ -70,8 +76,8 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) config = { "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, } } remove_db_data(self.metadata, config) @@ -85,7 +91,7 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. 
""" + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, @@ -93,12 +99,16 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: ) with self.assertRaises(AssertionError) as context_manager: meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) - remove_db_vocab(self.metadata, meta_dict, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_vocab( + self.metadata, + meta_dict, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @@ -120,7 +130,7 @@ def test_remove_tables(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. """ + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 41a77474..81e9eeac 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -40,6 +40,7 @@ class UniqueGeneratorTestCase(RequiresDBTestCase): and b which are boolean, and c which is a text column. There is a joint unique constraint on a and b, and a separate unique constraint on c. 
""" + dump_file_path = "unique_generator.dump" def setUp(self) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 0eca2b11..2640a9e1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import os import sys from pathlib import Path -from unittest.mock import patch, MagicMock, call +from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert from sqlalchemy.orm import declarative_base @@ -13,7 +13,7 @@ import_file, read_config_file, ) -from tests.utils import RequiresDBTestCase, DatafakerTestCase +from tests.utils import DatafakerTestCase, RequiresDBTestCase # pylint: disable=invalid-name Base = declarative_base() @@ -85,7 +85,9 @@ def test_download_table(self) -> None: conn.execute(insert(MyTable).values({"id": 1})) conn.commit() - download_table(MyTable.__table__, self.engine, self.mytable_file_path, compress=False) + download_table( + MyTable.__table__, self.engine, self.mytable_file_path, compress=False + ) # The .strip() gets rid of any possible empty lines at the end of the file. with Path("../examples/expected.yaml").open(encoding="utf-8") as yamlfile: @@ -108,124 +110,219 @@ def test_warns_of_invalid_config(self) -> None: "The config file is invalid: %s", "'a' is not of type 'integer'" ) + class TestUtils(DatafakerTestCase): - """ Miscellaneous tests. """ + """Miscellaneous tests.""" + def test_generators_require_stats(self) -> None: - """ Test that we can tell if a configuration requires SRC_STATS or not. 
""" - self.assertTrue(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}} - } - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, - }] - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "args": ['(SRC_STATS["my"]["results"])'], - }] - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "missingness_generators":[{ - "name": "msg", - "kwargs": {"a": '[SRC_STATS["my"], SRC_STATS["theirs"]]'}, - "columns_assigned": ["a"], - }] + """Test that we can tell if a configuration requires SRC_STATS or not.""" + self.assertTrue( + generators_require_stats( + { + "object_instantiation": { + "mygen": { + "name": "MyGen", + "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}, + } + } } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, + } + ] } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS["my"]["results"])'], + } + ] } - } - })) - self.assertFalse(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": 1}} - } - })) - self.assertFalse(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None]'}, - }] - })) - self.assertFalse(generators_require_stats({ - "story_generators": 
[{ - "name": "msg", - "args": ['(SRC_STATS_["my"]["results"])'], - }] - })) - self.assertFalse(generators_require_stats({ - "missingness_generators": [{ - "name": "msg", - "kwargs": {"a": '"SRC_STATS"'}, - "columns_assigned": ["a"], - }] - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "missingness_generators": [ + { + "name": "msg", + "kwargs": { + "a": '[SRC_STATS["my"], SRC_STATS["theirs"]]' + }, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATSS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "object_instantiation": { + "mygen": {"name": "MyGen", "kwargs": {"a": 1}} + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "kwargs": {"a": "[None]"}, + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS_["my"]["results"])'], + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "missingness_generators": [ + { + "name": "msg", + "kwargs": {"a": '"SRC_STATS"'}, + "columns_assigned": ["a"], + } + ] + } + ) + ) + self.assertFalse( + 
generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATSS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) @patch("datafaker.utils.logger") def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): - generators_require_stats({ - "story_generators": [ - {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} - ], - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['1 2'], - "columns_assigned": ["a"], - }] - } + generators_require_stats( + { + "story_generators": [ + {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} + ], + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["1 2"], + "columns_assigned": ["a"], + } + ] + } + }, } - }) - logger.error.assert_has_calls([ - call("Syntax error in argument %s of %s: %s\n%s\n%s", "b", "story_generators[0]", "unterminated string literal (detected at line 1)", "'unclosed", " ^"), - call("Syntax error in argument %d of %s: %s\n%s\n%s", 1, "tables.things.row_generators[0]", "invalid syntax", "1 2", " ^"), - ]) + ) + logger.error.assert_has_calls( + [ + call( + "Syntax error in argument %s of %s: %s\n%s\n%s", + "b", + "story_generators[0]", + "unterminated string literal (detected at line 1)", + "'unclosed", + " ^", + ), + call( + "Syntax error in argument %d of %s: %s\n%s\n%s", + 1, + "tables.things.row_generators[0]", + "invalid syntax", + "1 2", + " ^", + ), + ] + ) diff --git a/tests/utils.py b/tests/utils.py index 08850f85..a6eb5931 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,25 +1,26 @@ """Utilities for testing.""" import asyncio -from functools import lru_cache import os -from pathlib import Path import shutil -from 
sqlalchemy.schema import MetaData -from subprocess import run -import testing.postgresql import traceback +from functools import lru_cache +from pathlib import Path +from subprocess import run +from tempfile import mkstemp from typing import Any from unittest import TestCase, skipUnless -import yaml +import testing.postgresql +import yaml from sqlalchemy import MetaData -from tempfile import mkstemp +from sqlalchemy.schema import MetaData from datafaker import settings from datafaker.create import create_db_data_into -from datafaker.make import make_tables_file, make_src_stats, make_table_generators +from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import import_file, sorted_non_vocabulary_tables, create_db_engine +from datafaker.utils import create_db_engine, import_file, sorted_non_vocabulary_tables + class SysExit(Exception): """To force the function to exit as sys.exit() would.""" @@ -68,10 +69,10 @@ def assertFailure(self, result: Any) -> None: # pylint: disable=invalid-name self.assertReturnCode(result, 1) def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-name - """ Assert that the result has no exception. """ + """Assert that the result has no exception.""" if result.exception is None: return - self.fail(''.join(traceback.format_exception(result.exception))) + self.fail("".join(traceback.format_exception(result.exception))) def assertSubset(self, set1, set2, msg=None): """Assert a set is a (non-strict) subset. 
@@ -85,22 +86,23 @@ def assertSubset(self, set1, set2, msg=None): try: difference = set1.difference(set2) except TypeError as e: - self.fail('invalid type when attempting set difference: %s' % e) + self.fail("invalid type when attempting set difference: %s" % e) except AttributeError as e: - self.fail('first argument does not support set difference: %s' % e) + self.fail("first argument does not support set difference: %s" % e) if not difference: return lines = [] if difference: - lines.append('Items in the first set but not the second:') + lines.append("Items in the first set but not the second:") for item in difference: lines.append(repr(item)) - standardMsg = '\n'.join(lines) + standardMsg = "\n".join(lines) self.fail(self._formatMessage(msg, standardMsg)) + @skipUnless(shutil.which("psql"), "need to find 'psql': install PostgreSQL to enable") class RequiresDBTestCase(DatafakerTestCase): """ @@ -112,6 +114,7 @@ class RequiresDBTestCase(DatafakerTestCase): to get an engine to access the database and self.metadata to get metadata reflected from that engine. 
""" + schema_name = None use_asyncio = False examples_dir = "tests/examples" @@ -201,13 +204,15 @@ def get_src_stats(self, config) -> dict[str, any]: make_src_stats(self.dsn, config, self.metadata, self.schema_name) ) loop.close() - (self.stats_fd, self.stats_file_path) = mkstemp(".yaml", "src_stats_", text=True) + (self.stats_fd, self.stats_file_path) = mkstemp( + ".yaml", "src_stats_", text=True + ) with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh: stats_fh.write(yaml.dump(src_stats)) return src_stats def create_generators(self, config) -> None: - """ ``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py`` """ + """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" datafaker_content = make_table_generators( self.metadata, config, @@ -220,12 +225,12 @@ def create_generators(self, config) -> None: datafaker_fh.write(datafaker_content) def remove_data(self, config): - """ Remove source data from the DB. """ + """Remove source data from the DB.""" # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) def create_data(self, config, num_passes=1): - """ Create fake data in the DB. 
""" + """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) table_generator_dict = datafaker_module.table_generator_dict From bb14ced8f721077c0a548f3de569a39660a5aa2f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:02:28 +0100 Subject: [PATCH 02/44] Fixed a test --- tests/test_interactive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 94872845..8bfb7bc0 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1514,8 +1514,8 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertAlmostEqual( fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): """Test EAV for all columns with sampled and suppressed generation.""" From cda6164fbfe606f92b028520be0051e51b5c55de Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:31:03 +0100 Subject: [PATCH 03/44] base and create mypy fixes --- datafaker/base.py | 72 +++++++++++++++++++++--------------- datafaker/create.py | 26 ++++++++----- datafaker/templates/df.py.j2 | 2 +- tests/test_interactive.py | 4 +- 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 56315a40..0a3cf41f 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Callable, Generator, TypeVar import numpy as np import yaml @@ -24,13 +24,16 @@ ) +_T = TypeVar("_T") + + @functools.cache -def zipf_weights(size): +def 
zipf_weights(size: int) -> list[float]: total = sum(map(lambda n: 1 / n, range(1, size + 1))) return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list, constants_at: dict[int, any]): +def merge_with_constants(xs: list[_T], constants_at: dict[int, _T]) -> Generator[_T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. :param constants_at: A map of indices to objects that must be placed at @@ -59,41 +62,41 @@ def merge_with_constants(xs: list, constants_at: dict[int, any]): class NothingToGenerateException(Exception): - def __init__(self, message): + def __init__(self, message: str): super().__init__(message) class DistributionGenerator: root3 = math.sqrt(3) - def __init__(self): + def __init__(self) -> None: self.np_gen = np.random.default_rng() - def uniform(self, low, high) -> float: + def uniform(self, low: float, high: float) -> float: return random.uniform(float(low), float(high)) - def uniform_ms(self, mean, sd) -> float: + def uniform_ms(self, mean: float, sd: float) -> float: m = float(mean) h = self.root3 * float(sd) return random.uniform(m - h, m + h) - def normal(self, mean, sd) -> float: + def normal(self, mean: float, sd: float) -> float: return random.normalvariate(float(mean), float(sd)) - def lognormal(self, logmean, logsd) -> float: + def lognormal(self, logmean: float, logsd: float) -> float: return random.lognormvariate(float(logmean), float(logsd)) - def choice(self, a): + def choice(self, a: list[_T]) -> _T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a, n=None): + def zipf_choice(self, a: list[_T], n: int | None=None) -> _T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] return c["value"] if type(c) is dict and "value" in c else c - def weighted_choice(self, a: list[dict[str, any]]) -> list[any]: + def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ Choice weighted by 
the count in the original dataset. :param a: a list of dicts, each with a ``value`` key @@ -110,10 +113,10 @@ def weighted_choice(self, a: list[dict[str, any]]) -> list[any]: c = random.choices(vs, weights=counts)[0] return c - def constant(self, value): + def constant(self, value: _T) -> _T: return value - def multivariate_normal_np(self, cov): + def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: rank = int(cov["rank"]) if rank == 0: return np.empty(shape=(0,)) @@ -127,7 +130,7 @@ def multivariate_normal_np(self, cov): ] return self.np_gen.multivariate_normal(mean, covs) - def _select_group(self, alts: list[dict[str, any]]): + def _select_group(self, alts: list[dict[str, Any]]) -> Any: """ Choose one of the ``alts`` weighted by their ``"count"`` elements. """ @@ -148,14 +151,14 @@ def _select_group(self, alts: list[dict[str, any]]): return alt raise Exception("Internal error: ran out of choices in _select_group") - def _find_constants(self, result: dict[str, any]): + def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: """ Find all keys ``kN``, returning a dictionary of ``N: kNN``. This can be passed into ``merge_with_constants`` as the ``constants_at`` argument. """ - out: dict[int, any] = {} + out: dict[int, Any] = {} for k, v in result.items(): if k.startswith("k") and k[1:].isnumeric(): out[int(k[1:])] = v @@ -171,7 +174,7 @@ def _find_constants(self, result: dict[str, any]): "with_constants_at", } - def multivariate_normal(self, cov): + def multivariate_normal(self, cov: dict[str, Any]) -> list[float]: """ Produce a list of values pulled from a multivariate distribution. @@ -182,9 +185,10 @@ def multivariate_normal(self, cov): ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. 
:return: list of ``rank`` floating point values """ - return self.multivariate_normal_np(cov).tolist() + out: list[float] = self.multivariate_normal_np(cov).tolist() + return out - def multivariate_lognormal(self, cov): + def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: """ Produce a list of values pulled from a multivariate distribution. @@ -196,16 +200,23 @@ def multivariate_lognormal(self, cov): are all the means and covariants of the logs of the data. :return: list of ``rank`` floating point values """ - return np.exp(self.multivariate_normal_np(cov)).tolist() + out: list[Any] = np.exp(self.multivariate_normal_np(cov)).tolist() + return out - def grouped_multivariate_normal(self, covs): + def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: + """ + Produce a list of values pulled from a set of multivariate distributions. + """ cov = self._select_group(covs) logger.debug("Multivariate normal group selected: %s", cov) constants = self._find_constants(cov) nums = self.multivariate_normal(cov) return list(merge_with_constants(nums, constants)) - def grouped_multivariate_lognormal(self, covs): + def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: + """ + Produce a list of values pulled from a set of multivariate distributions. + """ cov = self._select_group(covs) logger.debug("Multivariate lognormal group selected: %s", cov) constants = self._find_constants(cov) @@ -217,8 +228,8 @@ def _check_generator_name(self, name: str) -> None: raise Exception("%s is not a permitted generator", name) def alternatives( - self, alternative_configs: list[dict[str, any]], counts: list[int] | None - ): + self, alternative_configs: list[dict[str, Any]], counts: list[dict[str, int]] | None + ) -> Any: """ A generator that picks between other generators. 
@@ -227,6 +238,9 @@ def alternatives( how often to use this alternative; "name" -- which generator for this partition, for example "composite"; "params" -- the parameters for this alternative. + :param counts: A list of weights for each alternative. If None, the + "count" value of each alternative is used. Each count is a dict + with a "count" key. :return: list of values """ if counts is not None: @@ -246,8 +260,8 @@ def alternatives( return getattr(self, name)(**alt["params"]) def with_constants_at( - self, constants_at: list[int], subgen: str, params: dict[str, any] - ): + self, constants_at: dict[int, _T], subgen: str, params: dict[str, _T] + ) -> list[_T]: if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. Valid names are %s.", @@ -258,7 +272,7 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn, params, length): + def truncated_string(self, subgen_fn: Callable[..., list[_T]], params: dict, length: int) -> list[_T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: @@ -340,7 +354,7 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None: class ColumnPresence: - def sampled(self, patterns): + def sampled(self, patterns: list[dict[str, Any]]) -> set[Any]: total = 0 for pattern in patterns: total += pattern.get("row_count", 0) diff --git a/datafaker/create.py b/datafaker/create.py index f5228762..e902ec3d 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -53,15 +53,15 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path | None = pathlib.Path("."), -) -> int: + base_path: pathlib.Path = pathlib.Path("."), +) -> list[str]: """ Load vocabulary tables from files. 
- arguments: - metadata: The schema of the database - meta_dict: The simple description of the schema from --orm-file - config: The configuration from --config-file + :param metadata: The schema of the database + :param meta_dict: The simple description of the schema from --orm-file + :param config: The configuration from --config-file + :return: List of table names loaded. """ settings = get_settings() dst_dsn: str = settings.dst_dsn or "" @@ -151,6 +151,8 @@ def __init__( self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict self._dst_conn: Connection = dst_conn + self._table_name: str | None + self._final_values: dict[str, Any] | None = None try: name, self._story = next(self._stories) logger.info("Generating data for story '%s'", name) @@ -165,13 +167,13 @@ def is_ended(self) -> bool: """ return self._table_name is None - def has_table(self, table_name: str): + def has_table(self, table_name: str) -> bool: """ Do we have a row for table table_name? """ return table_name == self._table_name - def table_name(self) -> str: + def table_name(self) -> str | None: """ The name of the current table (or None if no more stories to process) """ @@ -182,10 +184,12 @@ def insert(self) -> None: Perform the insert. Call this after __init__ or next, and after checking that is_ended returns False. 
""" + if self._table_name is None: + raise StopIteration("StoryIterator.insert after is_ended") table = self._table_dict[self._table_name] if table.name in self._table_generator_dict: table_generator = self._table_generator_dict[table.name] - default_values = table_generator(self._dst_conn, random.random) + default_values = table_generator(self._dst_conn) else: default_values = {} insert_values = {**default_values, **self._provided_values} @@ -273,7 +277,7 @@ def populate( with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): stmt = insert(table).values( - table_generator(dst_conn, random.random) + table_generator(dst_conn) ) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 @@ -286,6 +290,8 @@ def populate( while not story_iterator.is_ended(): story_iterator.insert() t = story_iterator.table_name() + if t is None: + raise Exception("Internal error") row_counts[t] = row_counts.get(t, 0) + 1 story_iterator.next() diff --git a/datafaker/templates/df.py.j2 b/datafaker/templates/df.py.j2 index 28c95827..87c84e4f 100644 --- a/datafaker/templates/df.py.j2 +++ b/datafaker/templates/df.py.j2 @@ -55,7 +55,7 @@ class {{ table_data.class_name }}(TableGenerator): def __init__(self): self.initialized = False - def __call__(self, dst_db_conn, get_random): + def __call__(self, dst_db_conn): if not self.initialized: {% for constraint in table_data.unique_constraints %} query_text = f"SELECT {% diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 8bfb7bc0..51163770 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1623,8 +1623,8 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertAlmostEqual( fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) - self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) + 
self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) stmt = select(self.metadata.tables[table2_name]) rows = conn.execute(stmt).fetchall() firsts = Stat() From e63c8bc24d80c6195547c7109e6e9313b38df3a0 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:51:48 +0100 Subject: [PATCH 04/44] Some mypy fixes to generator.py --- datafaker/generators.py | 83 +++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 1ea0a8f5..0a3f8c08 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -5,12 +5,13 @@ import decimal import math import re +import typing from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Callable, Iterable, TypeVar +from typing import Any, Callable, Iterable, Self, TypeVar import mimesis import mimesis.locales @@ -108,18 +109,18 @@ def custom_queries(self) -> dict[str, dict[str, str]]: return {} @abstractmethod - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ @abstractmethod - def generate_data(self, count) -> list[any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. """ - def fit(self, default=None) -> float | None: + def fit(self, default: float = None) -> float | None: """ Return a value representing how well the distribution fits the real source data. 
@@ -138,7 +139,7 @@ class PredefinedGenerator(Generator): AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') - def _get_src_stats_mentioned(self, val) -> set[str]: + def _get_src_stats_mentioned(self, val: Any) -> set[str]: if not val: return set() if type(val) is str: @@ -159,8 +160,8 @@ def _get_src_stats_mentioned(self, val) -> set[str]: def __init__( self, table_name: str, - generator_object: Mapping[str, any], - config: Mapping[str, any], + generator_object: Mapping[str, Any], + config: Mapping[str, Any], ): """ Initialise a generator from a config.yaml. @@ -226,13 +227,13 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: def custom_queries(self) -> dict[str, dict[str, str]]: return self._custom_queries - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: # Run the queries from nominal_kwargs # ... logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} - def generate_data(self, count) -> list[any]: + def generate_data(self, count: int) -> list[Any]: # Call the function if we can. This could be tricky... # ... logger.error("PredefinedGenerator.generate_data not implemented yet") @@ -286,7 +287,9 @@ def __init__( self.stddev = stddev @classmethod - def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): + def make_buckets( + _cls, engine: Engine, table_name: str, column_name: str + ) -> Self | None: """ Construct a Buckets object. 
@@ -391,7 +394,7 @@ class MimesisGenerator(MimesisGeneratorBase): def __init__( self, function_name: str, - value_fn: Callable[[any], float] | None = None, + value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): """ @@ -430,16 +433,16 @@ def __init__( self, function_name: str, length: int, - value_fn: Callable[[any], float] | None = None, + value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): self._length = length super().__init__(function_name, value_fn, buckets) - def function_name(self): + def function_name(self) -> str: return "dist_gen.truncated_string" - def name(self): + def name(self) -> str: return f"{self._name} [truncated to {self._length}]" def nominal_kwargs(self): @@ -998,7 +1001,7 @@ def __init__( self._annotation = "sampled and suppressed" @abstractmethod - def get_estimated_counts(counts): + def get_estimated_counts(self, counts): """ The counts that we would expect if this distribution was the correct one. """ @@ -1008,7 +1011,7 @@ def nominal_kwargs(self): "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } - def name(self): + def name(self) -> str: n = super().name() if self._annotation is None: return n @@ -1029,24 +1032,24 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None): + def fit(self, default=None) -> float | None: return default if self._fit is None else self._fit class ZipfChoiceGenerator(ChoiceGenerator): - def get_estimated_counts(self, counts): + def get_estimated_counts(self, counts: list[int]) -> list[int]: return list(zipf_distribution(sum(counts), len(counts))) - def function_name(self): + def function_name(self) -> str: return "dist_gen.zipf_choice" - def generate_data(self, count): + def generate_data(self, count: int) -> list[float]: return [ dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] -def uniform_distribution(total, bins): +def uniform_distribution(total, bins: int) 
-> typing.Generator[int, None, None]: p = total // bins n = total % bins for _ in range(0, n): @@ -1108,7 +1111,7 @@ def get_generators(self, columns: list[Column], engine: Engine): values = [] # The values found counts = [] # The number or each value cvs: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f @@ -1138,14 +1141,14 @@ def get_generators(self, columns: list[Column], engine: Engine): values = [] # All values found counts = [] # The number or each value cvs: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" values_not_suppressed = ( [] ) # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed cvs_not_suppressed: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f @@ -1229,10 +1232,10 @@ def function_name(self) -> str: def nominal_kwargs(self) -> dict[str, str]: return {"value": self.repr} - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: return {"value": self.value} - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: return [self.value for _ in range(count)] @@ -1289,13 +1292,13 @@ def custom_queries(self): } } - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ return {"cov": self._covariates} - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: """ Generate 'count' random data points for this column. 
""" @@ -1372,7 +1375,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1443,9 +1446,9 @@ class RowPartition: # list of included column values (so once the generator has # been run and the included_choice values have been # added): {index: value} - constant_outputs: dict[int, any] + constant_outputs: dict[int, Any] # The actual covariates from the source database - covariates: dict[str, float] + covariates: list[dict[str, float]] def comment(self) -> str: caveat = "" @@ -1506,10 +1509,10 @@ def __init__( else: self._name = f"null-partitioned {function_name}" - def name(self): + def name(self) -> str: return self._name - def function_name(self): + def function_name(self) -> str: return "dist_gen.alternatives" def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition): @@ -1599,7 +1602,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): }, } - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ @@ -1611,7 +1614,7 @@ def actual_kwargs(self) -> dict[str, any]: "counts": self._partition_counts, } - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: """ Generate 'count' random data points for this column. 
""" @@ -1630,7 +1633,7 @@ def is_numeric(col: Column) -> bool: T = TypeVar("T") -def powerset(input: Iterable[T]) -> Iterable[Iterable[T]]: +def powerset(input: list[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @@ -1736,7 +1739,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) @@ -1784,7 +1787,7 @@ def get_generators(self, columns: list[Column], engine: Engine): partition_def.nones, {}, ) - gens = [] + gens: list[Generator] = [] try: with engine.connect() as connection: partition_query_max = self.get_partition_count_query( @@ -1834,7 +1837,7 @@ def _execute_partition_queries( self, connection: Connection, partitions: dict[int, RowPartition], - ): + ) -> bool: """ Execute the query in each partition, filling in the covariates. :return: True if all the partitions work, False if any of them fail. 
@@ -1864,7 +1867,7 @@ def query_var(self, column: str) -> str: @lru_cache(1) -def everything_factory(): +def everything_factory() -> GeneratorFactory: return MultiGeneratorFactory( [ MimesisStringGeneratorFactory(), From f28d7b9bd752f7e48d87da14c37190fb138a1929 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Sun, 5 Oct 2025 23:38:28 +0100 Subject: [PATCH 05/44] Fixed variances in tests --- tests/examples/eav.sql | 4 ++-- tests/test_interactive.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/examples/eav.sql b/tests/examples/eav.sql index fe5879bd..c5af320b 100644 --- a/tests/examples/eav.sql +++ b/tests/examples/eav.sql @@ -22,8 +22,8 @@ INSERT INTO public.measurement_type VALUES (5, 'matter'); CREATE TABLE public.measurement ( id INTEGER NOT NULL, type INTEGER NOT NULL, - first_value INTEGER, - second_value INTEGER, + first_value FLOAT, + second_value FLOAT, third_value TEXT ); diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 51163770..284e04d0 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1490,32 +1490,32 @@ def test_create_with_null_partitioned_grouped_multivariate(self): two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 ) self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) - self.assertAlmostEqual(two.x_var(), 0.21, delta=0.4) + self.assertAlmostEqual(two.x_var(), 0.315, delta=0.18) self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(two.y_var(), 0.07, delta=0.1) - self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) + self.assertAlmostEqual(two.y_var(), 0.105, delta=0.06) + self.assertAlmostEqual(two.covar(), 0.105, delta=0.07) # type 3 self.assertAlmostEqual( three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(two.covar(), -0.5, delta=0.5) + self.assertAlmostEqual(three.covar(), -2.085, delta=1.1) # type 4 self.assertAlmostEqual( four.count(), generate_count * 3 / 20, 
delta=generate_count * 0.2 ) - self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) + self.assertAlmostEqual(four.covar(), 3.33, delta=1) # type 5/fish self.assertAlmostEqual( fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) + self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.6) # type 5/fowl self.assertAlmostEqual( fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): """Test EAV for all columns with sampled and suppressed generation.""" @@ -1618,13 +1618,13 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) + self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.5) # type 5/fowl self.assertAlmostEqual( fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) stmt = select(self.metadata.tables[table2_name]) rows = conn.execute(stmt).fetchall() firsts = Stat() From 20ff9eddebc84f6df5f23de941471e3fb7e06107 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 6 Oct 2025 16:59:55 +0100 Subject: [PATCH 06/44] Mypy fixes in generators.py --- datafaker/generators.py | 379 +++++++++++++++++++++++----------------- 1 file changed, 216 insertions(+), 163 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 0a3f8c08..dc1fdc0c 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ 
-11,17 +11,20 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, Self, TypeVar +from typing import Any, Callable, Iterable, Sequence, TypeVar, Union +from typing_extensions import Self import mimesis import mimesis.locales import sqlalchemy -from sqlalchemy import Column, Connection, Engine, RowMapping, Sequence, text -from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time +from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text +from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine from datafaker.base import DistributionGenerator from datafaker.utils import logger +numeric = Union[int, float] + # How many distinct values can we have before we consider a # choice distribution to be infeasible? MAXIMUM_CHOICES = 500 @@ -120,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float = None) -> float | None: + def fit(self, default: float | None=None) -> float | None: """ Return a value representing how well the distribution fits the real source data. @@ -179,7 +182,7 @@ def __init__( self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) # Need to deal with this somehow (or remove it from the schema) self._argn: list[str] = generator_object.get("args", []) - self._select_aggregate_clauses = {} + self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} self._custom_queries = {} for sstat in config.get("src-stats", []): name: str = sstat["name"] @@ -246,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. 
""" @@ -278,7 +281,7 @@ def __init__( ) ) ) - self.buckets = [0] * 10 + self.buckets: Sequence[int] = [0] * 10 for rb in raw_buckets: if rb.b is not None: bucket = min(9, max(0, int(rb.b) + 1)) @@ -288,7 +291,7 @@ def __init__( @classmethod def make_buckets( - _cls, engine: Engine, table_name: str, column_name: str + cls, engine: Engine, table_name: str, column_name: str ) -> Self | None: """ Construct a Buckets object. @@ -308,23 +311,23 @@ def make_buckets( ) ) ).first() - if result is None or result.stddev is None or result.count < 2: + if result is None or result.stddev is None or getattr(result, "count") < 2: return None try: - buckets = Buckets( + buckets = cls( engine, table_name, column_name, result.mean, result.stddev, - result.count, + getattr(result, "count"), ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("Failed to instantiate Buckets object: %s", exc) return None return buckets - def fit_from_counts(self, bucket_counts: list[float]) -> float: + def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: """ Figure out the fit from bucket counts from the generator distribution. """ @@ -350,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -383,10 +386,10 @@ def __init__( self._name = "generic." 
+ function_name self._generator_function = f - def function_name(self): + def function_name(self) -> str: return self._name - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [self._generator_function() for _ in range(count)] @@ -415,16 +418,16 @@ def __init__( samples = [value_fn(s) for s in samples] self._fit = buckets.fit_from_values(samples) - def function_name(self): + def function_name(self) -> str: return self._name - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return {} - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: return default if self._fit is None else self._fit @@ -445,21 +448,21 @@ def function_name(self) -> str: def name(self) -> str: return f"{self._name} [truncated to {self._length}]" - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "subgen_fn": self._name, "params": {}, "length": self._length, } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "subgen_fn": self._name, "params": {}, "length": self._length, } - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [self._generator_function()[: self._length] for _ in range(count)] @@ -472,7 +475,7 @@ def __init__( max_year: str, start: int, end: int, - ): + ) -> None: """ :param column: The column to generate into :param function_name: The name of the mimesis function @@ -489,7 +492,7 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str): + def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> list[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -512,13 +515,13 @@ def make_singleton(_cls, column: Column, engine: 
Engine, function_name: str): ) ] - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "start": self._start, "end": self._end, @@ -536,14 +539,14 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ self._generator_function(start=self._start, end=self._end) for _ in range(count) ] -def get_column_type(column: Column): +def get_column_type(column: Column) -> TypeEngine: try: return column.type.as_generic() except NotImplementedError: @@ -589,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -632,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -653,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -668,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. 
""" - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -685,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -700,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -710,26 +713,28 @@ def get_generators(self, columns: list[Column], engine: Engine): return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: list[float], ys: list[float]): +def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) class ContinuousDistributionGenerator(Generator): + expected_buckets: Sequence[numeric] = [] + def __init__(self, table_name: str, column_name: str, buckets: Buckets): super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: if self.buckets is None: return {} return { @@ -751,7 +756,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, 
str]]: }, } - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -771,10 +776,10 @@ class GaussianGenerator(ContinuousDistributionGenerator): 0.0227, ] - def function_name(self): + def function_name(self) -> str: return "dist_gen.normal" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -795,10 +800,10 @@ class UniformGenerator(ContinuousDistributionGenerator): 0, ] - def function_name(self): + def function_name(self) -> str: return "dist_gen.uniform_ms" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -822,7 +827,7 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -869,19 +874,19 @@ def __init__( self.logmean = logmean self.logstddev = logstddev - def function_name(self): + def function_name(self) -> str: return "dist_gen.lognormal" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "logmean": self.logmean, "logsd": self.logstddev, @@ -901,7 
+906,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -941,7 +946,15 @@ def _get_generators_from_buckets( ] -def zipf_distribution(total, bins): +def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Get a zipf distribution for a certain number of items distributed + in a certain number of bins. + :param total: The total number of items to be distributed. + :param bins: The total number of bins to distribute the items into. + :return: A generator of the number of items in each bin, from the + largest to the smallest. + """ basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) bd_remaining = sum(basic_dist) for b in basic_dist: @@ -960,13 +973,13 @@ class ChoiceGenerator(Generator): def __init__( self, - table_name, - column_name, - values, - counts, - sample_count=None, - suppress_count=0, - ): + table_name: str, + column_name : str, + values: list[Any], + counts: list[int], + sample_count: int | None=None, + suppress_count: int=0, + ) -> None: super().__init__() self.table_name = table_name self.column_name = column_name @@ -1001,12 +1014,12 @@ def __init__( self._annotation = "sampled and suppressed" @abstractmethod - def get_estimated_counts(self, counts): + def get_estimated_counts(self, counts: list[int]) -> list[int]: """ The counts that we would expect if this distribution was the correct one. 
""" - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } @@ -1017,7 +1030,7 @@ def name(self) -> str: return n return f"{n} [{self._annotation}]" - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "a": self.values, } @@ -1032,7 +1045,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None) -> float | None: + def fit(self, default: float | None=None) -> float | None: return default if self._fit is None else self._fit @@ -1049,7 +1062,12 @@ def generate_data(self, count: int) -> list[float]: ] -def uniform_distribution(total, bins: int) -> typing.Generator[int, None, None]: +def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + A generator putting ``total`` items uniformly into ``bins`` bins. + If they don't fit exactly evenly, the earlier bins will have one more + item than the later bins so the total is as required. + """ p = total // bins n = total % bins for _ in range(0, n): @@ -1059,29 +1077,86 @@ def uniform_distribution(total, bins: int) -> typing.Generator[int, None, None]: class UniformChoiceGenerator(ChoiceGenerator): - def get_estimated_counts(self, counts): + """ + A generator producing values, each roughly as frequently as each other. 
+ """ + def get_estimated_counts(self, counts: list[int]) -> list[int]: return list(uniform_distribution(sum(counts), len(counts))) - def function_name(self): + def function_name(self) -> str: return "dist_gen.choice" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [dist_gen.choice(self.values) for _ in range(count)] class WeightedChoiceGenerator(ChoiceGenerator): STORE_COUNTS = True - def get_estimated_counts(self, counts): + def get_estimated_counts(self, counts: list[int]) -> list[int]: return counts - def function_name(self): + def function_name(self) -> str: return "dist_gen.weighted_choice" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [dist_gen.weighted_choice(self.values) for _ in range(count)] +class ValueGatherer: + """ + Gathers values from a query of values and counts. + + The query must return columns ``v`` for a value and ``f`` for the + count of how many of those values there are. + These values will be gathered into a number of properties: + ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts + in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and + ``count`` giving these values and counts. ``counts_not_suppressed``, + ``values_not_suppressed`` and ``cvs_not_suppressed`` are the + equivalents with the counts less than or equal to ``suppress_count`` + removed. + + :param suppress_count: value with a count of this or fewer will be excluded + from the suppressed values. 
+ """ + def __init__(self, results: CursorResult, suppress_count: int=0) -> None: + values = [] # All values found + counts = [] # The number or each value + cvs: list[ + dict[str, Any] + ] = [] # list of dicts with keys "v" and "count" + values_not_suppressed = ( + [] + ) # All values found more than SUPPRESS_COUNT times + counts_not_suppressed = [] # The number for each value not suppressed + cvs_not_suppressed: list[ + dict[str, Any] + ] = [] # list of dicts with keys "v" and "count" + for result in results: + c = result.f + if c != 0: + counts.append(c) + v = result.v + if type(v) is decimal.Decimal: + v = float(v) + values.append(v) + cvs.append({"value": v, "count": c}) + if suppress_count < c: + counts_not_suppressed.append(c) + v = result.v + if type(v) is decimal.Decimal: + v = float(v) + values_not_suppressed.append(v) + cvs_not_suppressed.append({"value": v, "count": c}) + self.values = values + self.counts = counts + self.cvs = cvs + self.values_not_suppressed = values_not_suppressed + self.counts_not_suppressed = counts_not_suppressed + self.cvs_not_suppressed = cvs_not_suppressed + + class ChoiceGeneratorFactory(GeneratorFactory): """ All generators that want an average and standard deviation. 
@@ -1090,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1108,25 +1183,12 @@ def get_generators(self, columns: list[Column], engine: Engine): ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: - values = [] # The values found - counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if counts: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts), - UniformChoiceGenerator(table_name, column_name, values, counts), - WeightedChoiceGenerator(table_name, column_name, cvs, counts), + ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), + UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), + WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), ] results = connection.execute( text( @@ -1138,81 +1200,59 @@ def get_generators(self, columns: list[Column], engine: Engine): ) ) if results is not None: - values = [] # All values found - counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = ( - [] - ) # All values found more than SUPPRESS_COUNT times - counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if 
type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if self.SUPPRESS_COUNT < c: - counts_not_suppressed.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values_not_suppressed.append(v) - cvs_not_suppressed.append({"value": v, "count": c}) - if counts: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), + UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), + WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), + ] generators += [ ZipfChoiceGenerator( table_name, column_name, - values, - counts, + vg.values, + vg.counts, sample_count=self.SAMPLE_COUNT, ), UniformChoiceGenerator( table_name, column_name, - values, - counts, + vg.values, + vg.counts, sample_count=self.SAMPLE_COUNT, ), WeightedChoiceGenerator( table_name, column_name, - cvs, - counts, + vg.cvs, + vg.counts, sample_count=self.SAMPLE_COUNT, ), ] - if counts_not_suppressed: + if vg.counts_not_suppressed: generators += [ ZipfChoiceGenerator( table_name, column_name, - values_not_suppressed, - counts_not_suppressed, + vg.values_not_suppressed, + vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), UniformChoiceGenerator( table_name, column_name, - values_not_suppressed, - counts_not_suppressed, + vg.values_not_suppressed, + vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), WeightedChoiceGenerator( table_name=table_name, column_name=column_name, - values=cvs_not_suppressed, - counts=counts, + values=vg.cvs_not_suppressed, + counts=vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), @@ -1221,7 +1261,7 @@ def get_generators(self, columns: list[Column], engine: Engine): class ConstantGenerator(Generator): - def __init__(self, value): + def __init__(self, value: Any) 
-> None: super().__init__() self.value = value self.repr = repr(value) @@ -1235,7 +1275,7 @@ def nominal_kwargs(self) -> dict[str, str]: def actual_kwargs(self) -> dict[str, Any]: return {"value": self.value} - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: return [self.value for _ in range(count)] @@ -1244,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1261,29 +1301,32 @@ def get_generators(self, columns: list[Column], engine: Engine): class MultivariateNormalGenerator(Generator): + """ + Generator of multiple values drawn from a multivariate normal distribution. + """ def __init__( self, - table_name: list[str], + table_name: str, column_names: list[str], query: str, - covariates: dict[str, float], + covariates: RowMapping, function_name: str, - ): + ) -> None: self._table = table_name self._columns = column_names self._query = query self._covariates = covariates self._function_name = function_name - def function_name(self): + def function_name(self) -> str: return "dist_gen." + self._function_name - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', } - def custom_queries(self): + def custom_queries(self) -> dict[str, Any]: cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { @@ -1298,7 +1341,7 @@ def actual_kwargs(self) -> dict[str, Any]: """ return {"cov": self._covariates} - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. 
""" @@ -1307,7 +1350,7 @@ def generate_data(self, count) -> list[Any]: for _ in range(count) ] - def fit(self, default=None) -> float | None: + def fit(self, default: float | None=None) -> float | None: return default @@ -1375,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1417,17 +1460,23 @@ def query_var(self, column: str) -> str: return f"LN({column})" -def text_list(items: list[str]) -> str: +def text_list(items: Iterable[str]) -> str: """ Concatenate the items with commas and one "and". """ - if not hasattr(items, "__getitem__"): - items = list(items) - if len(items) == 0: + item_i = iter(items) + try: + last_item = next(item_i) + except StopIteration: return "" - if len(items) == 1: - return items[0] - return ", ".join(items[:-1]) + " and " + items[-1] + try: + so_far = next(item_i) + except StopIteration: + return last_item + for item in item_i: + so_far += ", " + last_item + last_item = item + return so_far + " and " + last_item @dataclass @@ -1448,7 +1497,7 @@ class RowPartition: # added): {index: value} constant_outputs: dict[int, Any] # The actual covariates from the source database - covariates: list[dict[str, float]] + covariates: Sequence[RowMapping] def comment(self) -> str: caveat = "" @@ -1495,7 +1544,7 @@ def __init__( function_name: str = "grouped_multivariate_lognormal", name_suffix: str | None = None, partition_count_query: str | None = None, - partition_counts: Sequence[RowMapping] | None = None, + partition_counts: Iterable[RowMapping] = [], partition_count_comment: str | None = None, ): self._query_name = query_name @@ -1515,7 +1564,7 @@ def name(self) -> str: def function_name(self) -> str: return "dist_gen.alternatives" - def 
_nominal_kwargs_with_combinations(self, index: int, partition: RowPartition): + def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) -> dict[str, Any]: count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1542,12 +1591,10 @@ def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) }, } - def _count_query_name(self): - if self._partition_count_query: - return f"auto__cov__{self._query_name}__counts" - return None + def _count_query_name(self) -> str: + return f"auto__cov__{self._query_name}__counts" - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "alternative_configs": [ self._nominal_kwargs_with_combinations(index, self._partitions[index]) @@ -1556,7 +1603,7 @@ def nominal_kwargs(self): "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', } - def custom_queries(self): + def custom_queries(self) -> dict[str, Any]: partitions = { f"auto__cov__{self._query_name}__alt_{index}": { "comment": partition.comment(), @@ -1574,7 +1621,7 @@ def custom_queries(self): **partitions, } - def _actual_kwargs_with_combinations(self, partition: RowPartition): + def _actual_kwargs_with_combinations(self, partition: RowPartition) -> dict[str, Any]: count = sum(row["count"] for row in partition.covariates) if not partition.included_numeric and not partition.included_choice: return { @@ -1614,14 +1661,14 @@ def actual_kwargs(self) -> dict[str, Any]: "counts": self._partition_counts, } - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. 
""" kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default=None) -> float | None: + def fit(self, default: float | None=None) -> float | None: return default @@ -1690,6 +1737,12 @@ def __init__( class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 + EMPTY_RESULT = [RowMapping( + parent=sqlalchemy.engine.result.ResultMetaData(), + processors=None, + key_to_index={"count": 0}, + data=(0,) + )] def function_name(self) -> str: return "grouped_multivariate_normal" @@ -1739,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) @@ -1767,7 +1820,7 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato partition_def.included_choice, partition_def.excluded, partition_def.nones, - {}, + [], ) query = self.query( table=table, @@ -1785,7 +1838,7 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato partition_def.included_choice, partition_def.excluded, partition_def.nones, - {}, + [], ) gens: list[Generator] = [] try: @@ -1846,7 +1899,7 @@ def _execute_partition_queries( for rp in partitions.values(): rp.covariates = connection.execute(text(rp.query)).mappings().fetchall() if not rp.covariates or rp.covariates[0]["count"] is None: - rp.covariates = [{"count": 0}] + rp.covariates = self.EMPTY_RESULT else: found_nonzero = True return found_nonzero From 49954de69cea1fb005369413631919320e1f9e9c Mon Sep 17 00:00:00 2001 
From: Tim Band Date: Mon, 6 Oct 2025 18:59:32 +0100 Subject: [PATCH 07/44] mypy fixes in interactive.py --- datafaker/base.py | 20 ++- datafaker/generators.py | 40 ++--- datafaker/interactive.py | 336 +++++++++++++++++++++++++-------------- datafaker/utils.py | 4 +- 4 files changed, 250 insertions(+), 150 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 0a3cf41f..acdc9b95 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Generator, TypeVar +from typing import Any, Callable, Generator import numpy as np import yaml @@ -18,22 +18,20 @@ from datafaker.utils import ( MAKE_VOCAB_PROGRESS_REPORT_EVERY, + T, logger, stream_yaml, table_row_count, ) -_T = TypeVar("_T") - - @functools.cache def zipf_weights(size: int) -> list[float]: total = sum(map(lambda n: 1 / n, range(1, size + 1))) return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list[_T], constants_at: dict[int, _T]) -> Generator[_T, None, None]: +def merge_with_constants(xs: list[T], constants_at: dict[int, T]) -> Generator[T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. 
:param constants_at: A map of indices to objects that must be placed at @@ -86,11 +84,11 @@ def normal(self, mean: float, sd: float) -> float: def lognormal(self, logmean: float, logsd: float) -> float: return random.lognormvariate(float(logmean), float(logsd)) - def choice(self, a: list[_T]) -> _T: + def choice(self, a: list[T]) -> T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a: list[_T], n: int | None=None) -> _T: + def zipf_choice(self, a: list[T], n: int | None=None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] @@ -113,7 +111,7 @@ def weighted_choice(self, a: list[dict[str, Any]]) -> Any: c = random.choices(vs, weights=counts)[0] return c - def constant(self, value: _T) -> _T: + def constant(self, value: T) -> T: return value def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: @@ -260,8 +258,8 @@ def alternatives( return getattr(self, name)(**alt["params"]) def with_constants_at( - self, constants_at: dict[int, _T], subgen: str, params: dict[str, _T] - ) -> list[_T]: + self, constants_at: dict[int, T], subgen: str, params: dict[str, T] + ) -> list[T]: if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. 
Valid names are %s.", @@ -272,7 +270,7 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn: Callable[..., list[_T]], params: dict, length: int) -> list[_T]: + def truncated_string(self, subgen_fn: Callable[..., list[T]], params: dict, length: int) -> list[T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: diff --git a/datafaker/generators.py b/datafaker/generators.py index dc1fdc0c..57a3c640 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -123,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: """ Return a value representing how well the distribution fits the real source data. @@ -249,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: """ Returns all the generators that might be appropriate for this column. 
""" @@ -353,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: return [ generator for factory in self.factories @@ -427,7 +427,7 @@ def nominal_kwargs(self) -> dict[str, Any]: def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default if self._fit is None else self._fit @@ -592,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -756,7 +756,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -827,7 +827,7 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -906,7 +906,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -1045,7 +1045,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | 
None: + def fit(self, default: float=-1) -> float: return default if self._fit is None else self._fit @@ -1165,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1284,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1350,7 +1350,7 @@ def generate_data(self, count: int) -> list[Any]: for _ in range(count) ] - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default @@ -1418,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1668,7 +1668,7 @@ def generate_data(self, count: int) -> list[Any]: kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default @@ -1792,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> 
Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 4f5e8c55..ee9b9409 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -3,10 +3,12 @@ import functools import re from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Collection, Mapping from dataclasses import dataclass from enum import Enum from pathlib import Path +from typing import Any, Callable, Iterable, cast +from typing_extensions import Self import sqlalchemy from prettytable import PrettyTable @@ -108,12 +110,12 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: ... def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] ): super().__init__() - self.config = config + self.config: Mapping[str, Any] = config self.metadata = metadata - self.table_entries: list[TableEntry] = [] + self._table_entries: Collection[TableEntry] = [] tables_config: Mapping = config.get("tables", {}) if type(tables_config) is not dict: tables_config = {} @@ -127,16 +129,16 @@ def __init__( self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.engine.dispose() - def print(self, text: str, *args, **kwargs): + def print(self, text: str, *args, **kwargs) -> None: print(text.format(*args, **kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]): + def print_table(self, headings: list[str], rows: list[list[str]]) -> None: output = PrettyTable() output.field_names = headings for row in rows: @@ 
-159,7 +161,7 @@ def ask_save(self): return ask.result def set_table_index(self, index) -> bool: - if 0 <= index and index < len(self.table_entries): + if 0 <= index and index < len(self._table_entries): self.table_index = index self.set_prompt() return True @@ -172,7 +174,7 @@ def next_table(self, report="No more tables"): return True def table_name(self): - return self.table_entries[self.table_index].name + return self._table_entries[self.table_index].name def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] @@ -195,21 +197,21 @@ def report_columns(self): ], ) - def get_table_config(self, table_name: str) -> dict[str, any]: + def get_table_config(self, table_name: str) -> dict[str, Any]: ts = self.config.get("tables", None) if type(ts) is not dict: return {} t = ts.get(table_name) return t if type(t) is dict else {} - def set_table_config(self, table_name: str, config: dict[str, any]): + def set_table_config(self, table_name: str, config: dict[str, Any]): ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} return ts[table_name] = config - def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: + def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: src_stats = self.config.get("src-stats", []) new_src_stats = [] for stat in src_stats: @@ -218,7 +220,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: self.config["src-stats"] = new_src_stats return new_src_stats - def get_nonnull_columns(self, table_name: str): + def get_nonnull_columns(self, table_name: str) -> list[str]: metadata_table = self.metadata.tables[table_name] return [ str(name) @@ -230,21 +232,21 @@ def find_entry_index_by_table_name(self, table_name) -> int | None: return next( ( i - for i, entry in enumerate(self.table_entries) + for i, entry in enumerate(self._table_entries) if entry.name == table_name ), None, ) def find_entry_by_table_name(self, 
table_name) -> TableEntry | None: - for e in self.table_entries: + for e in self._table_entries: if e.name == table_name: return e return None def do_counts(self, _arg): "Report the column names with the counts of nulls in them" - if len(self.table_entries) <= self.table_index: + if len(self._table_entries) <= self.table_index: return table_name = self.table_name() nonnull_columns = self.get_nonnull_columns(table_name) @@ -289,14 +291,14 @@ def do_select(self, arg): rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)] self.print_table(fields, rows) - def do_peek(self, arg: str): + def do_peek(self, arg: str) -> None: """ Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table. Use 'peek' to see a sample of the current column(s). Rows that are enitrely null are suppressed. """ MAX_PEEK_ROWS = 25 - if len(self.table_entries) <= self.table_index: + if len(self._table_entries) <= self.table_index: return table_name = self.table_name() col_names = arg.split() @@ -319,8 +321,8 @@ def do_peek(self, arg: str): rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] self.print_table(list(result.keys()), rows) - def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int): - if len(self.table_entries) <= self.table_index: + def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + if len(self._table_entries) <= self.table_index: return [] return [ col for col in self.table_metadata().columns.keys() if col.startswith(text) @@ -369,18 +371,22 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: def __init__( self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping - ): + ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() - def set_prompt(self): + @property + def table_entries(self) -> list[TableCmdTableEntry]: + return cast(TableCmdTableEntry, self._table_entries) + + def set_prompt(self) -> 
None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) else: self.prompt = "(table) " - def set_type(self, t_type: TableType): + def set_type(self, t_type: TableType) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type @@ -428,7 +434,6 @@ def _sanity_check_failures(self) -> list[tuple[str, str, str]]: """Find tables that reference each other that should not given their types.""" failures = [] for from_entry in self.table_entries: - from_entry: TableCmdTableEntry from_t = from_entry.new_type if from_t == TableType.VOCABULARY: referenced = self._get_referenced_tables(from_entry.name) @@ -451,7 +456,6 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: """Find tables that reference each other that might cause problems given their types.""" warnings = [] for from_entry in self.table_entries: - from_entry: TableCmdTableEntry from_t = from_entry.new_type if from_t in {TableType.GENERATE, TableType.PRIVATE}: referenced = self._get_referenced_tables(from_entry.name) @@ -470,7 +474,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: ) return warnings - def do_quit(self, _arg): + def do_quit(self, _arg: str) -> bool: "Check the updates, save them if desired and quit the configurer." 
count = 0 for entry in self.table_entries: @@ -502,7 +506,7 @@ def do_quit(self, _arg): return True return False - def do_tables(self, _arg): + def do_tables(self, _arg: str) -> None: "list the tables with their types" for entry in self.table_entries: old = entry.old_type @@ -510,7 +514,7 @@ def do_tables(self, _arg): becomes = " " if old == new else "->" + TYPE_LETTER[new] self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) - def do_next(self, arg): + def do_next(self, arg: str) -> None: "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any @@ -522,51 +526,51 @@ def do_next(self, arg): return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) - def do_ignore(self, _arg): + def do_ignore(self, _arg: str) -> None: "Set the current table as ignored, and go to the next table" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() - def do_vocabulary(self, _arg): + def do_vocabulary(self, _arg: str) -> None: "Set the current table as a vocabulary table, and go to the next table" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", self.table_name()) self.next_table() - def do_private(self, _arg): + def do_private(self, _arg: str) -> None: "Set the current table as a primary private table (such as the table of patients)" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() - def do_generate(self, _arg): + def 
do_generate(self, _arg: str) -> None: "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() - def do_empty(self, _arg): + def do_empty(self, _arg: str) -> None: "Set the current table as empty; no generators will be run for it" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() - def do_columns(self, _arg): + def do_columns(self, _arg: str) -> None: "Report the column names and metadata" self.report_columns() - def do_data(self, arg: str): + def do_data(self, arg: str) -> None: """ Report some data. 'data' = report a random ten lines, @@ -606,14 +610,14 @@ def do_data(self, arg: str): number = 48 self.print_column_data(column, number, min_length) - def complete_data(self, text, line, begidx, endidx): + def complete_data(self, text: str, line: str, begidx: int, _endidx: int) -> list[str]: previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] table_metadata = self.table_metadata() return [k for k in table_metadata.columns.keys() if k.startswith(text)] - def print_column_data(self, column: str, count: int, min_length: int): + def print_column_data(self, column: str, count: int, min_length: int) -> None: where = f"WHERE {column} IS NOT NULL" if 0 < min_length: where = "WHERE LENGTH({column}) >= {len}".format( @@ -633,7 +637,7 @@ def print_column_data(self, column: str, count: int, min_length: int): ) self.columnize([str(x[0]) for x in result.all()]) - def print_row_data(self, count: int): + def print_row_data(self, count: int) -> None: with self.engine.connect() as connection: result = connection.execute( text( @@ -651,7 +655,7 @@ def print_row_data(self, count: int): def update_config_tables( src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping -): +) -> Mapping[str, Any]: with TableCmd(src_dsn, src_schema, metadata, 
config) as tc: tc.cmdloop() return tc.config @@ -671,7 +675,7 @@ class MissingnessType: columns: list[str] @classmethod - def sampled_query(cls, table, count, column_names) -> str: + def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str: result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) column_is_nulls = ", ".join( ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] @@ -713,9 +717,9 @@ def find_missingness_query( for src_stat in self.config["src-stats"]: if src_stat.get("name") == key: return (src_stat.get("query", None), src_stat.get("comment", None)) - return None + return None - def make_table_entry(self, name: str, table: Mapping) -> TableEntry: + def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -737,7 +741,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: elif len(mgs) == 1: mg = mgs[0] mg_name = mg.get("name", None) - if mg_name is not None: + if type(mg_name) is str: query_comment = self.find_missingness_query(mg) if query_comment is not None: (query, comment) = query_comment @@ -758,10 +762,24 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: def __init__( self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping ): + """ + Initialise a MissingnessCmd. + :param src_dsn: connection string for the source database. + :param src_schema: schema name for the source database. + :param metadata: SQLAlchemy metadata for the source database. + :param config: Configuration from the ``config.yaml`` file. 
+ """ super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() - def set_prompt(self): + @property + def table_entries(self) -> list[MissingnessCmdTableEntry]: + return cast(MissingnessCmdTableEntry, self._table_entries) + + def set_prompt(self) -> None: + """ + Sets the prompt according to the current table and missingness. + """ if self.table_index < len(self.table_entries): entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] nt = entry.new_type @@ -772,7 +790,7 @@ def set_prompt(self): else: self.prompt = "(missingness) " - def set_type(self, t_type: TableType): + def set_type(self, t_type: TableType) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type @@ -780,7 +798,6 @@ def set_type(self, t_type: TableType): def _copy_entries(self) -> None: src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: - entry: MissingnessCmdTableEntry table = self.get_table_config(entry.name) if entry.new_type is None or entry.new_type.name == "none": table.pop("missingness_generators", None) @@ -808,7 +825,7 @@ def _copy_entries(self) -> None: ) self.set_table_config(entry.name, table) - def do_quit(self, _arg): + def do_quit(self, _arg: str) -> bool: "Check the updates, save them if desired and quit the configurer." 
count = 0 for entry in self.table_entries: @@ -843,7 +860,7 @@ def do_quit(self, _arg): return True return False - def do_tables(self, arg): + def do_tables(self, _arg: str) -> None: "list the tables with their types" for entry in self.table_entries: old = "-" if entry.old_type is None else entry.old_type.name @@ -851,7 +868,7 @@ def do_tables(self, arg): desc = new if old == new else "{0}->{1}".format(old, new) self.print("{0} {1}", entry.name, desc) - def do_next(self, arg): + def do_next(self, arg: str) -> None: "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any @@ -866,17 +883,20 @@ def do_next(self, arg): return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) - def _set_type(self, name, query, comment): + def _set_type(self, name: str, query: str, comment: str) -> None: + """ + Set the current table entry's query. + """ if len(self.table_entries) <= self.table_index: return entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] @@ -887,13 +907,15 @@ def _set_type(self, name, query, comment): columns=self.get_nonnull_columns(entry.name), ) - def _set_none(self): + def _set_none(self) -> None: + """ + Sets the current table to have no missingness applied. 
+ """ if len(self.table_entries) <= self.table_index: return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = None + self.table_entries[self.table_index].new_type = None - def do_sampled(self, arg: str): + def do_sampled(self, arg: str) -> None: """ Set the current table missingness as 'sampled', and go to the next table. "sampled 3000" means sample 3000 rows at random and choose the missingness @@ -903,7 +925,7 @@ def do_sampled(self, arg: str): if len(self.table_entries) <= self.table_index: self.print("Error! not on a table") return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] + entry = self.table_entries[self.table_index] if arg == "": count = 1000 elif arg.isdecimal(): @@ -926,7 +948,7 @@ def do_sampled(self, arg: str): self.print("Table {} set to sampled missingness", self.table_name()) self.next_table() - def do_none(self, _arg): + def do_none(self, _arg: str) -> None: "Set the current table to have no missingness, and go to the next table" self._set_none() self.print("Table {} set to have no missingness", self.table_name()) @@ -934,8 +956,8 @@ def do_none(self, _arg): def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping -): + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] +) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() return mc.config @@ -943,17 +965,28 @@ def update_missingness( @dataclass class GeneratorInfo: + """ + A generator and the columns it assigns to. + """ columns: list[str] gen: Generator | None @dataclass class GeneratorCmdTableEntry(TableEntry): + """ + List of generators set for a table. + Includes the original setting and the currently configured + generators. + """ old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] class GeneratorCmd(DbCmd): + """ + Interactive command shell for setting generators. 
+ """ intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the current column, then command 'compare' to see how these perform @@ -1052,21 +1085,40 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping - ): + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + ) -> None: + """ + Initialise a GeneratorCmd + :param src_dsn: connection address for source database + :param src_schema: database schema name + :param metadata: SQLAlchemy metadata for the source database + :param config: Configuration loaded from ``config.yaml`` + """ super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 self.generators_valid_columns = None self.set_prompt() - def set_table_index(self, index): + @property + def table_entries(self) -> list[GeneratorCmdTableEntry]: + return cast(GeneratorCmdTableEntry, self._table_entries) + + def set_table_index(self, index: int) -> bool: + """ + Moves to a new table. + :param index: table index to move to. + """ ret = super().set_table_index(index) if ret: self.generator_index = 0 self.set_prompt() return ret - def previous_table(self): + def previous_table(self) -> bool: + """ + Move to the table before the current one. + :return: True if there is a previous table to go to. + """ ret = self.set_table_index(self.table_index - 1) if ret: table = self.get_table() @@ -1082,29 +1134,44 @@ def previous_table(self): return ret def get_table(self) -> GeneratorCmdTableEntry | None: + """ + Get the current table entry. + """ if self.table_index < len(self.table_entries): return self.table_entries[self.table_index] return None def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: + """ + Gets a pair; the table name then the generator information. 
+ """ if self.table_index < len(self.table_entries): - entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] + entry = self.table_entries[self.table_index] if self.generator_index < len(entry.new_generators): return (entry.name, entry.new_generators[self.generator_index]) return (entry.name, None) return (None, None) def get_column_names(self) -> list[str]: + """ + Gets the (unqualified) names for all the current columns. + """ (_, generator_info) = self.get_table_and_generator() return generator_info.columns if generator_info else [] def column_metadata(self) -> list[Column]: + """ + Gets the metadata for all the current columns. + """ table = self.table_metadata() if table is None: return [] return [table.columns[name] for name in self.get_column_names()] - def set_prompt(self): + def set_prompt(self) -> None: + """ + Set the prompt according to the current table, column and generator. + """ (table_name, gen_info) = self.get_table_and_generator() if table_name is None: self.prompt = "(generators) " @@ -1119,13 +1186,18 @@ def set_prompt(self): gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - def _remove_auto_src_stats(self) -> list[dict[str, any]]: + def _remove_auto_src_stats(self) -> list[dict[str, Any]]: + """ + Remove all automatic source stats (which we assume is + every source stats query whose name begins with ``auto__`)""" return self._remove_prefix_src_stats("auto__") def _copy_entries(self) -> None: + """ + Set generator and query information in the configuration. 
+ """ src_stats = self._remove_auto_src_stats() - tes: list[GeneratorCmdTableEntry] = self.table_entries - for entry in tes: + for entry in self.table_entries: rgs = [] new_gens: list[Generator] = [] for generator in entry.new_generators: @@ -1173,7 +1245,7 @@ def _copy_entries(self) -> None: self.config["src-stats"] = src_stats def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns + self, entry: GeneratorCmdTableEntry, columns: Iterable[list] ) -> Generator | None: """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) @@ -1182,12 +1254,12 @@ def _find_old_generator( return gen.gen return None - def do_quit(self, arg): + def do_quit(self, arg: str) -> bool: "Check the updates, save them if desired and quit the configurer." count = 0 for entry in self.table_entries: header_shown = False - g_entry: GeneratorCmdTableEntry = entry + g_entry = cast(GeneratorCmdTableEntry, entry) for gen in g_entry.new_generators: old_gen = self._find_old_generator(g_entry, gen.columns) new_gen = None if gen is None else gen.gen @@ -1215,19 +1287,20 @@ def do_quit(self, arg): return True return False - def do_tables(self, arg): + def do_tables(self, arg: str) -> None: "list the tables" - for entry in self.table_entries: + for t_entry in self.table_entries: + entry = cast(GeneratorCmdTableEntry, t_entry) gen_count = len(entry.new_generators) how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" self.print("{0} ({1})", entry.name, how_many) - def do_list(self, arg): + def do_list(self, arg: str) -> None: "list the generators in the current table" if len(self.table_entries) <= self.table_index: self.print("Error: no table {0}", self.table_index) return - g_entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] + g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) table = self.table_metadata() for gen in g_entry.new_generators: old_gen = 
self._find_old_generator(g_entry, gen.columns) @@ -1245,11 +1318,11 @@ def do_list(self, arg): primary = "[primary-key]" self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) - def do_columns(self, _arg): + def do_columns(self, _arg: str) -> None: "Report the column names and metadata" self.report_columns() - def do_info(self, _arg): + def do_info(self, _arg: str) -> None: "Show information about the current column" for cm in self.column_metadata(): self.print( @@ -1279,14 +1352,14 @@ def _get_table_index(self, table_name: str) -> int | None: return n return None - def _get_generator_index(self, table_index, column_name): - entry: GeneratorCmdTableEntry = self.table_entries[table_index] + def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + entry = self.table_entries[table_index] for n, gen in enumerate(entry.new_generators): if column_name in gen.columns: return n return None - def go_to(self, target): + def go_to(self, target: str) -> bool: parts = target.split(".", 1) table_index = self._get_table_index(parts[0]) if table_index is None: @@ -1310,7 +1383,7 @@ def go_to(self, target): self.set_prompt() return True - def do_next(self, arg): + def do_next(self, arg: str) -> None: """ Go to the next generator. Or go to a named table: 'next tablename'. 
@@ -1322,14 +1395,14 @@ def do_next(self, arg): else: self._go_next() - def do_n(self, arg): + def do_n(self, arg: str) -> None: """Synonym for next""" self.do_next(arg) - def complete_n(self, text: str, line: str, begidx: int, endidx: int): + def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: return self.complete_next(text, line, begidx, endidx) - def _go_next(self): + def _go_next(self) -> None: table = self.get_table() if table is None: self.print("No more tables") @@ -1340,7 +1413,7 @@ def _go_next(self): self.generator_index = next_gi self.set_prompt() - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1348,7 +1421,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): table_index = self._get_table_index(first_part) if table_index is None: return [] - table_entry: GeneratorCmdTableEntry = self.table_entries[table_index] + table_entry = self.table_entries[table_index] return [ f"{first_part}.{column}" for gen in table_entry.new_generators @@ -1374,7 +1447,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): column_names = [] return table_names + column_names - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: """Go to the previous generator""" if self.generator_index == 0: self.previous_table() @@ -1382,17 +1455,24 @@ def do_previous(self, _arg): self.generator_index -= 1 self.set_prompt() - def do_b(self, arg): + def do_b(self, arg: str) -> None: """Synonym for previous""" self.do_previous(arg) def _generators_valid(self) -> bool: + """ + Return True if the self.generators property is still correct for the + table and columns currently being examined. 
+ """ return self.generators_valid_columns == ( self.table_index, self.get_column_names(), ) def _get_generator_proposals(self) -> list[Generator]: + """ + Get a list of acceptable generators, sorted by decreasing fit to the actual data. + """ if not self._generators_valid(): self.generators = None if self.generators is None: @@ -1406,7 +1486,10 @@ def _get_generator_proposals(self) -> list[Generator]: ) return self.generators - def _print_privacy(self): + def _print_privacy(self) -> None: + """ + Print the privacy status of the current table. + """ table = self.table_metadata() if table is None: return @@ -1419,7 +1502,7 @@ def _print_privacy(self): return self.print(self.SECONDARY_PRIVATE_TEXT, pfks) - def do_compare(self, arg: str): + def do_compare(self, arg: str) -> None: """ Compare the real data with some generators. @@ -1448,11 +1531,11 @@ def do_compare(self, arg: str): self._print_values_queried(table_name, n, gen) self.print_table_by_columns(comparison) - def do_c(self, arg): + def do_c(self, arg: str) -> None: """Synonym for compare.""" self.do_compare(arg) - def _print_values_queried(self, table_name: str, n: int, gen: Generator): + def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: """ Print the values queried from the database for this generator. 
""" @@ -1478,7 +1561,7 @@ def _print_custom_queries(self, gen: Generator) -> None: cqs = gen.custom_queries() if not cqs: return - cq_key2args = {} + cq_key2args: dict[str, Any] = {} nominal = gen.nominal_kwargs() actual = gen.actual_kwargs() self._get_custom_queries_from( @@ -1493,7 +1576,7 @@ def _print_custom_queries(self, gen: Generator) -> None: cq_key2args[cq_key], ) - def _get_custom_queries_from(self, out, nominal, actual): + def _get_custom_queries_from(self, out: dict[str, Any], nominal: Any, actual: Any) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) if src_stat_groups: @@ -1524,7 +1607,7 @@ def _get_aggregate_query( return None return f"SELECT {', '.join(clauses)} FROM {table_name}" - def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: + def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: """ Prints the select aggregate query and all the values it gets in this case. """ @@ -1554,7 +1637,7 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; providing the following values: {1}", select_q, vals) - def _get_column_data(self, count: int, to_str=repr): + def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> list[list[str]]: columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) @@ -1566,7 +1649,7 @@ def _get_column_data(self, count: int, to_str=repr): ) return [[to_str(x) for x in xs] for xs in result.all()] - def do_propose(self, _arg): + def do_propose(self, _arg: str) -> None: """ Display a list of possible generators for this column. 
@@ -1585,8 +1668,8 @@ def do_propose(self, _arg): if not gens: self.print(self.PROPOSE_NOTHING) for index, gen in enumerate(gens): - fit = gen.fit() - if fit is None: + fit = gen.fit(-1) + if fit == -1: fit_s = "(no fit)" elif fit < 100: fit_s = f"(fit: {fit:.3g})" @@ -1600,7 +1683,7 @@ def do_propose(self, _arg): sample="; ".join(map(repr, gen.generate_data(limit))), ) - def do_p(self, arg): + def do_p(self, arg: str) -> None: """Synonym for propose""" self.do_propose(arg) @@ -1610,7 +1693,7 @@ def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: return gen return None - def do_set(self, arg: str): + def do_set(self, arg: str) -> None: """ Set one of the proposals as a generator. Takes a single integer argument. @@ -1619,6 +1702,7 @@ def do_set(self, arg: str): self.print("Please run 'propose' before 'set '") return gens = self._get_generator_proposals() + new_gen: Generator | None if arg.isdigit(): index = int(arg) if index < 1: @@ -1639,7 +1723,10 @@ def do_set(self, arg: str): self.set_generator(new_gen) self._go_next() - def set_generator(self, gen: Generator): + def set_generator(self, gen: Generator | None) -> None: + """ + Set the current column's generator. + """ (table, gen_info) = self.get_table_and_generator() if table is None: self.print("Error: no table") @@ -1649,23 +1736,23 @@ def set_generator(self, gen: Generator): return gen_info.gen = gen - def do_s(self, arg): + def do_s(self, arg: str) -> None: """Synonym for set""" self.do_set(arg) - def do_unset(self, _arg): + def do_unset(self, _arg: str) -> None: """ Remove any generator set for this column. 
""" self.set_generator(None) self._go_next() - def do_merge(self, arg: str): + def do_merge(self, arg: str) -> None: """Add this column(s) to the specified column(s), so one generator covers them all.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return @@ -1713,9 +1800,9 @@ def do_merge(self, arg: str): table_entry.new_generators = new_new_generators self.set_prompt() - def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: return [] return [ @@ -1726,12 +1813,12 @@ def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): if column.startswith(last_arg) ] - def do_unmerge(self, arg: str): + def do_unmerge(self, arg: str) -> None: """Remove this column(s) from this generator, make them a separate generator.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return @@ -1766,9 +1853,9 @@ def do_unmerge(self, arg: str): ) self.set_prompt() - def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: return [] return [ @@ -1782,9 +1869,22 @@ def 
update_config_generators( src_dsn: str, src_schema: str, metadata: MetaData, - config: Mapping, + config: Mapping[str, Any], spec_path: Path | None, -): +) -> Mapping[str, Any]: + """ + Update configuration with the specification from a CSV file. + The specification is a headerless CSV file with columns: Table name, + Column name (or space-separated list of column names), Generator + name required, Second choice generator name, Third choice generator + name, etcetera. + :param src_dsn: Address of the source database + :param src_schema: Name of the source database schema to read from + :param metadata: SQLAlchemy representation of the source database + :param config: Existing configuration (will not be destructively updated) + :param spec_path: The path of the CSV file containing the specification + :return: Updated configuration. + """ with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: if spec_path is None: gc.cmdloop() diff --git a/datafaker/utils.py b/datafaker/utils.py index 883e0969..a94edf72 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -7,7 +7,7 @@ import sys from pathlib import Path from types import ModuleType -from typing import Any, Final, Mapping, Optional, Union +from typing import Any, Final, Mapping, Optional, TypeVar, Union import sqlalchemy import yaml @@ -38,6 +38,8 @@ Path(__file__).parent / "json_schemas/config_schema.json" ) +T = TypeVar("T") + def read_config_file(path: str) -> dict: """Read a config file, warning if it is invalid. 
From 3e01c574ddaa0ee07780a527f3fa451c9edbd31a Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 6 Oct 2025 21:23:17 +0100 Subject: [PATCH 08/44] More mypy fixes in interactive.py --- datafaker/interactive.py | 121 +++++++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index ee9b9409..c95dd92c 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -3,11 +3,12 @@ import functools import re from abc import ABC, abstractmethod -from collections.abc import Collection, Mapping +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Any, Callable, Iterable, cast +from types import TracebackType +from typing import Any, Callable, Iterable, Optional, Type, cast from typing_extensions import Self import sqlalchemy @@ -16,6 +17,7 @@ from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( + T, create_db_engine, fk_refers_to_ignored_table, logger, @@ -30,12 +32,12 @@ import readline if not hasattr(readline, "backend"): - readline.backend = "readline" + setattr(readline, "backend", "readline") except: pass -def or_default(v, d): +def or_default(v: T | None, d: T) -> T: """Returns v if it isn't None, otherwise d.""" return d if v is None else v @@ -75,27 +77,27 @@ class AskSaveCmd(cmd.Cmd): prompt = "(yes/no/cancel) " file = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.result = "" - def do_yes(self, _arg): + def do_yes(self, _arg: str) -> bool: self.result = "yes" return True - def do_no(self, _arg): + def do_no(self, _arg: str) -> bool: self.result = "no" return True - def do_cancel(self, _arg): + def do_cancel(self, _arg: str) -> bool: self.result = "cancel" return True -def fk_column_name(fk: ForeignKey): +def fk_column_name(fk: ForeignKey) -> str: if 
fk_refers_to_ignored_table(fk): return f"{fk.target_fullname} (ignored)" - return fk.target_fullname + return str(fk.target_fullname) class DbCmd(ABC, cmd.Cmd): @@ -106,16 +108,16 @@ class DbCmd(ABC, cmd.Cmd): ROW_COUNT_MSG = "Total row count: {}" @abstractmethod - def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: + def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None: ... def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ): super().__init__() - self.config: Mapping[str, Any] = config + self.config: MutableMapping[str, Any] = config self.metadata = metadata - self._table_entries: Collection[TableEntry] = [] + self._table_entries: list[TableEntry] = [] tables_config: Mapping = config.get("tables", {}) if type(tables_config) is not dict: tables_config = {} @@ -125,56 +127,79 @@ def __init__( table_config = {} entry = self.make_table_entry(name, table_config) if entry is not None: - self.table_entries.append(entry) + self._table_entries.append(entry) self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__( + self, + _exc_type: Optional[Type[BaseException]], + _exc_val: Optional[BaseException], + _exc_tb: Optional[TracebackType], + ) -> None: self.engine.dispose() - def print(self, text: str, *args, **kwargs) -> None: + def print(self, text: str, *args: Any, **kwargs: Any) -> None: print(text.format(*args, **kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]) -> None: + def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: output = PrettyTable() output.field_names = headings for row in rows: output.add_row(row) print(output) - def print_table_by_columns(self, columns: dict[str, list[str]]): + 
def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): output.add_column(field_name, data + [None] * (row_count - len(data))) print(output) - def print_results(self, result): + def print_results(self, result: sqlalchemy.CursorResult) -> None: self.print_table(list(result.keys()), [list(row) for row in result.all()]) - def ask_save(self): + def ask_save(self) -> str: + """ + Ask the user if they want to save. + :return: ``yes``, ``no`` or ``cancel``. + """ ask = AskSaveCmd() ask.cmdloop() return ask.result - def set_table_index(self, index) -> bool: + @abstractmethod + def set_prompt(self) -> None: + ... + + def set_table_index(self, index: int) -> bool: + """ + Move to a different table. + :param index: Index of the table to move to. + :return: True if there is a table with such an index to move to. + """ if 0 <= index and index < len(self._table_entries): self.table_index = index self.set_prompt() return True return False - def next_table(self, report="No more tables"): + def next_table(self, report: str="No more tables") -> bool: + """ + Move to the next table + :return: True if there is another table to move to. + """ if not self.set_table_index(self.table_index + 1): self.print(report) return False return True - def table_name(self): - return self._table_entries[self.table_index].name + def table_name(self) -> str: + """ Get the name of the current table. 
""" + return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] @@ -182,7 +207,7 @@ def table_metadata(self) -> Table: def get_column_names(self) -> list[str]: return [col.name for col in self.table_metadata().columns] - def report_columns(self): + def report_columns(self) -> None: self.print_table( ["name", "type", "primary", "nullable", "foreign key"], [ @@ -204,7 +229,7 @@ def get_table_config(self, table_name: str) -> dict[str, Any]: t = ts.get(table_name) return t if type(t) is dict else {} - def set_table_config(self, table_name: str, config: dict[str, Any]): + def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} @@ -228,7 +253,7 @@ def get_nonnull_columns(self, table_name: str) -> list[str]: if column.nullable ] - def find_entry_index_by_table_name(self, table_name) -> int | None: + def find_entry_index_by_table_name(self, table_name: str) -> int | None: return next( ( i @@ -238,13 +263,13 @@ def find_entry_index_by_table_name(self, table_name) -> int | None: None, ) - def find_entry_by_table_name(self, table_name) -> TableEntry | None: + def find_entry_by_table_name(self, table_name: str) -> TableEntry | None: for e in self._table_entries: if e.name == table_name: return e return None - def do_counts(self, _arg): + def do_counts(self, _arg: str) -> None: "Report the column names with the counts of nulls in them" if len(self._table_entries) <= self.table_index: return @@ -274,7 +299,7 @@ def do_counts(self, _arg): ], ) - def do_select(self, arg): + def do_select(self, arg: str) -> None: "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 with self.engine.connect() as connection: @@ -358,7 +383,7 @@ class TableCmd(DbCmd): NOTE_TEXT_NO_CHANGES = "You have made no changes." 
NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" - def make_table_entry(self, name: str, table: Mapping) -> TableEntry: + def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None: if table.get("ignore", False): return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE) if table.get("vocabulary_table", False): @@ -370,14 +395,14 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @property def table_entries(self) -> list[TableCmdTableEntry]: - return cast(TableCmdTableEntry, self._table_entries) + return cast(list[TableCmdTableEntry], self._table_entries) def set_prompt(self) -> None: if self.table_index < len(self.table_entries): @@ -393,7 +418,6 @@ def set_type(self, t_type: TableType) -> None: def _copy_entries(self) -> None: for entry in self.table_entries: - entry: TableCmdTableEntry if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) if ( @@ -654,7 +678,7 @@ def print_row_data(self, count: int) -> None: def update_config_tables( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping ) -> Mapping[str, Any]: with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() @@ -719,7 +743,7 @@ def find_missingness_query( return (src_stat.get("query", None), src_stat.get("comment", None)) return None - def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None: + def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ 
-760,7 +784,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None: ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping ): """ Initialise a MissingnessCmd. @@ -774,7 +798,7 @@ def __init__( @property def table_entries(self) -> list[MissingnessCmdTableEntry]: - return cast(MissingnessCmdTableEntry, self._table_entries) + return cast(list[MissingnessCmdTableEntry], self._table_entries) def set_prompt(self) -> None: """ @@ -956,7 +980,7 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1018,7 +1042,7 @@ class GeneratorCmd(DbCmd): r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
) - def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None: + def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -1028,7 +1052,8 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None metadata_table = self.metadata.tables[table_name] columns = [str(colname) for colname in metadata_table.columns.keys()] column_set = frozenset(columns) - columns_assigned_so_far = set() + columns_assigned_so_far: set[str] = set() + new_generator_infos: list[GeneratorInfo] = [] old_generator_infos: list[GeneratorInfo] = [] for rg in table.get("row_generators", []): @@ -1054,7 +1079,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) actual_collist = [c for c in collist if c in columns] if actual_collist: - gen = PredefinedGenerator(table, rg, self.config) + gen = PredefinedGenerator(table_name, rg, self.config) new_generator_infos.append( GeneratorInfo( columns=actual_collist.copy(), @@ -1085,7 +1110,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: """ Initialise a GeneratorCmd @@ -1096,12 +1121,12 @@ def __init__( """ super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 - self.generators_valid_columns = None + self.generators_valid_columns: Optional[tuple[int, list[str]]] = None self.set_prompt() @property def table_entries(self) -> list[GeneratorCmdTableEntry]: - return cast(GeneratorCmdTableEntry, self._table_entries) + return cast(list[GeneratorCmdTableEntry], self._table_entries) def set_table_index(self, index: int) -> bool: """ @@ -1245,7 +1270,7 @@ def _copy_entries(self) -> None: 
self.config["src-stats"] = src_stats def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns: Iterable[list] + self, entry: GeneratorCmdTableEntry, columns: Iterable[str] ) -> Generator | None: """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) @@ -1869,7 +1894,7 @@ def update_config_generators( src_dsn: str, src_schema: str, metadata: MetaData, - config: Mapping[str, Any], + config: MutableMapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ From ea02cf3a0dfd6260b1bfe651d9043dd84f4caf1c Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 11:22:38 +0100 Subject: [PATCH 09/44] mypy clean: dump, generators, interactive, providers --- datafaker/dump.py | 5 ++-- datafaker/generators.py | 34 +++++++++++++------------- datafaker/interactive.py | 53 ++++++++++++++++++++++++++++++---------- datafaker/providers.py | 2 +- 4 files changed, 61 insertions(+), 33 deletions(-) diff --git a/datafaker/dump.py b/datafaker/dump.py index 36ca046c..5f819234 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,15 @@ +from _csv import Writer import csv import io import sqlalchemy from sqlalchemy.schema import MetaData -from datafaker.settings import get_settings from datafaker.utils import create_db_engine, get_sync_engine, logger -def _make_csv_writer(file): +def _make_csv_writer(file: io.TextIOBase) -> Writer: + """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py b/datafaker/generators.py index 57a3c640..ff728582 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, Sequence, TypeVar, Union +from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union from typing_extensions import Self import 
mimesis @@ -249,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. """ @@ -353,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -492,7 +492,7 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> list[Generator]: + def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> Sequence[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -592,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -821,13 +821,13 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> list[Generator]: + ) -> Sequence[Generator]: return [ GaussianGenerator(table_name, column_name, buckets), UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -923,7 +923,7 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> list[Generator]: + ) -> Sequence[Generator]: with engine.connect() as connection: result = connection.execute( text( @@ -1165,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1284,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1418,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: 
return [] @@ -1792,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index c95dd92c..d0355a11 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -13,12 +13,13 @@ import sqlalchemy from prettytable import PrettyTable -from sqlalchemy import Column, ForeignKey, MetaData, Table, text +from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( T, create_db_engine, + get_sync_engine, fk_refers_to_ignored_table, logger, primary_private_fks, @@ -131,6 +132,10 @@ def __init__( self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) + @property + def sync_engine(self) -> Engine: + return get_sync_engine(self.engine) + def __enter__(self) -> Self: return self @@ -276,7 +281,7 @@ def do_counts(self, _arg: str) -> None: table_name = self.table_name() nonnull_columns = self.get_nonnull_columns(table_name) colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( @@ -302,7 +307,7 @@ def do_counts(self, _arg: str) -> None: def do_select(self, arg: str) -> None: "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 - with 
self.engine.connect() as connection: + with self.sync_engine.connect() as connection: try: result = connection.execute(text("SELECT " + arg)) except sqlalchemy.exc.DatabaseError as exc: @@ -330,7 +335,7 @@ def do_peek(self, arg: str) -> None: if not col_names: col_names = self.get_column_names() nonnulls = [cn + " IS NOT NULL" for cn in col_names] - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format( cols=",".join(col_names), table=table_name, @@ -404,6 +409,12 @@ def __init__( def table_entries(self) -> list[TableCmdTableEntry]: return cast(list[TableCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(TableCmdTableEntry, entry) + def set_prompt(self) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] @@ -648,7 +659,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: column=column, len=min_length, ) - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( @@ -662,7 +673,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: self.columnize([str(x[0]) for x in result.all()]) def print_row_data(self, count: int) -> None: - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( @@ -715,7 +726,7 @@ def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> s @dataclass class MissingnessCmdTableEntry(TableEntry): old_type: MissingnessType - new_type: 
MissingnessType + new_type: MissingnessType | None class MissingnessCmd(DbCmd): @@ -731,7 +742,7 @@ class MissingnessCmd(DbCmd): def find_missingness_query( self, missingness_generator: Mapping - ) -> tuple[str | None, str | None] | None: + ) -> tuple[str, str] | None: """Find query and comment from src-stats for the passed missingness generator.""" kwargs = missingness_generator.get("kwargs", {}) patterns = kwargs.get("patterns", "") @@ -740,7 +751,10 @@ def find_missingness_query( key = pattern_match.group(1) for src_stat in self.config["src-stats"]: if src_stat.get("name") == key: - return (src_stat.get("query", None), src_stat.get("comment", None)) + query = src_stat.get("query", None) + if type(query) is not str: + return None + return (query, src_stat.get("comment", "")) return None def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: @@ -800,6 +814,12 @@ def __init__( def table_entries(self) -> list[MissingnessCmdTableEntry]: return cast(list[MissingnessCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> MissingnessCmdTableEntry | None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(MissingnessCmdTableEntry, entry) + def set_prompt(self) -> None: """ Sets the prompt according to the current table and missingness. 
@@ -814,7 +834,7 @@ def set_prompt(self) -> None: else: self.prompt = "(missingness) " - def set_type(self, t_type: TableType) -> None: + def set_type(self, t_type: MissingnessType) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type @@ -1128,6 +1148,12 @@ def __init__( def table_entries(self) -> list[GeneratorCmdTableEntry]: return cast(list[GeneratorCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> GeneratorCmdTableEntry | None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(GeneratorCmdTableEntry, entry) + def set_table_index(self, index: int) -> bool: """ Moves to a new table. @@ -1239,7 +1265,7 @@ def _copy_entries(self) -> None: else [], } ) - rg = { + rg: dict[str, Any] = { "name": generator.gen.function_name(), "columns_assigned": generator.columns, } @@ -1431,6 +1457,7 @@ def _go_next(self) -> None: table = self.get_table() if table is None: self.print("No more tables") + return next_gi = self.generator_index + 1 if next_gi == len(table.new_generators): self.next_table(self.INFO_NO_MORE_TABLES) @@ -1502,7 +1529,7 @@ def _get_generator_proposals(self) -> list[Generator]: self.generators = None if self.generators is None: columns = self.column_metadata() - gens = everything_factory().get_generators(columns, self.engine) + gens = everything_factory().get_generators(columns, self.sync_engine) gens.sort(key=lambda g: g.fit(9999)) self.generators = gens self.generators_valid_columns = ( @@ -1666,7 +1693,7 @@ def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> lis columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( f"SELECT {columns_string} FROM 
{self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" diff --git a/datafaker/providers.py b/datafaker/providers.py index 1ebd5bf5..9639f1b5 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -29,7 +29,7 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed=None, **kwargs): + def __init__(self, *, seed: int | None=None, **kwargs: Any) -> None: super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} From 63f781e5c0f37c1c41b2852b7121041fcf9bf8da Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 12:42:22 +0100 Subject: [PATCH 10/44] Mypy fixed dump, interactive, main, serialize_metadata --- datafaker/dump.py | 6 +++-- datafaker/interactive.py | 18 +++++++------- datafaker/main.py | 36 +++++++++++++++------------- datafaker/serialize_metadata.py | 42 ++++++++++++++++++--------------- 4 files changed, 56 insertions(+), 46 deletions(-) diff --git a/datafaker/dump.py b/datafaker/dump.py index 5f819234..4c309110 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,16 @@ -from _csv import Writer import csv import io +from typing import TYPE_CHECKING import sqlalchemy from sqlalchemy.schema import MetaData from datafaker.utils import create_db_engine, get_sync_engine, logger +if TYPE_CHECKING: + from _csv import Writer -def _make_csv_writer(file: io.TextIOBase) -> Writer: +def _make_csv_writer(file: io.TextIOBase) -> "Writer": """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index d0355a11..111a277a 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -113,7 +113,7 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non ... 
def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ): super().__init__() self.config: MutableMapping[str, Any] = config @@ -400,7 +400,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -689,7 +689,7 @@ def print_row_data(self, count: int) -> None: def update_config_tables( - src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping ) -> Mapping[str, Any]: with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() @@ -798,7 +798,7 @@ def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntr ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping ): """ Initialise a MissingnessCmd. 
@@ -1000,7 +1000,7 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1130,7 +1130,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTable ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: """ Initialise a GeneratorCmd @@ -1530,8 +1530,8 @@ def _get_generator_proposals(self) -> list[Generator]: if self.generators is None: columns = self.column_metadata() gens = everything_factory().get_generators(columns, self.sync_engine) - gens.sort(key=lambda g: g.fit(9999)) - self.generators = gens + sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) + self.generators = sorted_gens self.generators_valid_columns = ( self.table_index, self.get_column_names().copy(), @@ -1919,7 +1919,7 @@ def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> def update_config_generators( src_dsn: str, - src_schema: str, + src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any], spec_path: Path | None, diff --git a/datafaker/main.py b/datafaker/main.py index c22f0979..65b0102a 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,11 +1,12 @@ """Entrypoint for the datafaker package.""" import asyncio +import io import json import sys from enum import Enum from importlib import metadata from pathlib import Path -from typing import Final, Optional +from typing import Any, Final, Optional import yaml from jsonschema.exceptions import ValidationError @@ -68,7 +69,7 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def 
load_metadata_config(orm_file_name, config: dict | None = None): +def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: with open(orm_file_name) as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) tables_dict = meta_dict.get("tables", {}) @@ -80,12 +81,12 @@ def load_metadata_config(orm_file_name, config: dict | None = None): return meta_dict -def load_metadata(orm_file_name, config: dict | None = None): +def load_metadata(orm_file_name: str, config: dict | None = None) -> Any: meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) -def load_metadata_for_output(orm_file_name, config: dict | None = None): +def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any: """ Load metadata excluding any foreign keys pointing to ignored tables. """ @@ -96,7 +97,7 @@ def load_metadata_for_output(orm_file_name, config: dict | None = None): @app.callback() def main( verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") -): +) -> None: conf_logger(verbose) @@ -202,7 +203,7 @@ def create_tables( def create_generators( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option(DF_FILENAME, help="Path to write Python generators to."), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: Optional[str] = Option( None, help=( @@ -348,11 +349,11 @@ def make_tables( @app.command() def configure_tables( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path to write the configuration file to" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): +) -> None: """ Interactively set tables to ignored, vocabulary or primary private. 
""" @@ -380,11 +381,11 @@ def configure_tables( @app.command() def configure_missing( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path to write the configuration file to" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): +) -> None: """ Interactively set the missingness of the generated data. """ @@ -392,11 +393,13 @@ def configure_missing( settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) config_file_path = Path(config_file) - config = {} + config: dict[str, Any] = {} if config_file_path.exists(): - config = yaml.load( + config_any = yaml.load( config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) + if type(config_any) is dict: + config = config_any metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) if config_updated is None: @@ -409,7 +412,7 @@ def configure_missing( @app.command() def configure_generators( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path of the configuration file to alter" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), @@ -417,7 +420,7 @@ def configure_generators( None, help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", ), -): +) -> None: """ Interactively set generators for column data. 
""" @@ -450,7 +453,7 @@ def dump_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: str = Argument(help="The table to dump"), output: str | None = Option(None, help="output CSV file name"), -): +) -> None: """Dump a whole table as a CSV file (or to the console) from the destination database.""" settings = get_settings() dst_dsn: str = settings.dst_dsn or "" @@ -459,7 +462,8 @@ def dump_data( config = read_config_file(config_file) if config_file is not None else {} metadata = load_metadata_for_output(orm_file, config) if output == None: - dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) + if isinstance(sys.stdout, io.TextIOBase): + dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return with open(output, "wt", newline="") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 303c2c76..936eb9f3 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,17 +1,21 @@ -from typing import Callable +from typing import Callable, Protocol import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.sql import schema, sqltypes +import typing from datafaker.utils import make_foreign_key_name -table_component_t = dict[str, any] -table_t = dict[str, table_component_t] +table_t = dict[str, typing.Any] -def simple(type_): +# We will change this to parsy.Parser when parsy exports its types properly +ParserType = typing.Any + + +def simple(type_: type) -> ParserType: """ Parses a simple sqltypes type. For example, simple(sqltypes.UUID) takes the string "UUID" and outputs @@ -20,14 +24,14 @@ def simple(type_): return parsy.string(type_.__name__).result(type_) -def integer(): +def integer() -> ParserType: """ Parses an integer, outputting that integer. 
""" return parsy.regex(r"-?[0-9]+").map(int) -def integer_arguments(): +def integer_arguments() -> ParserType: """ Parses a list of integers. The integers are surrounded by brackets and separated by @@ -38,7 +42,7 @@ def integer_arguments(): ) -def numeric_type(type_): +def numeric_type(type_: type) -> ParserType: """ Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. @@ -48,9 +52,9 @@ def numeric_type(type_): ) -def string_type(type_): +def string_type(type_: type) -> ParserType: @parsy.generate(type_.__name__) - def st_parser(): + def st_parser() -> typing.Generator[ParserType, None, typing.Any]: """ Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" or TYPE_NAME(32) COLLATE "fr" @@ -67,9 +71,9 @@ def st_parser(): return st_parser -def time_type(type_, pg_type): +def time_type(type_: type, pg_type: type) -> ParserType: @parsy.generate(type_.__name__) - def pgt_parser(): + def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: """ Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE or TYPE_NAME(32) WITH TIME ZONE @@ -125,7 +129,7 @@ def pgt_parser(): @parsy.generate -def type_parser(): +def type_parser() -> ParserType: base = yield SIMPLE_TYPE_PARSER dimensions = yield parsy.string("[]").many().map(len) if dimensions == 0: @@ -133,7 +137,7 @@ def type_parser(): return postgresql.ARRAY(base, dimensions=dimensions) -def column_to_dict(column: Column, dialect: Dialect) -> str: +def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: type_ = column.type if isinstance(type_, postgresql.DOMAIN): # Instead of creating a restricted type, we'll just use the base type. 
@@ -156,8 +160,8 @@ def column_to_dict(column: Column, dialect: Dialect) -> str: def dict_to_column( - table_name, - col_name, + table_name: str, + col_name: str, rep: dict, ignore_fk: Callable[[str], bool], ) -> Column: @@ -236,7 +240,7 @@ def dict_to_table( def metadata_to_dict( meta: MetaData, schema_name: str | None, engine: Engine -) -> dict[str, table_t]: +) -> dict[str, typing.Any]: """ Converts a SQL Alchemy MetaData object into a Python object ready for conversion to YAML. @@ -251,7 +255,7 @@ def metadata_to_dict( } -def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): +def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool: """ Tell if this foreign key should be ignored because it points to an ignored table. @@ -261,10 +265,10 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): return True if fk_bits[0] not in tables_dict: return False - return tables_dict[fk_bits[0]].get("ignore", False) + return bool(tables_dict[fk_bits[0]].get("ignore", False)) -def dict_to_metadata(obj: dict, config_for_output: dict = None) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict | None=None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. 
From 3a5527b5007dd06ece3140c26e4dd97d15ec058b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:29:07 +0100 Subject: [PATCH 11/44] mypy clean in datafaker dir --- datafaker/make.py | 127 ++++++++++++++++++++++++++------------------- datafaker/utils.py | 100 +++++++++++++++++++++++++---------- 2 files changed, 149 insertions(+), 78 deletions(-) diff --git a/datafaker/make.py b/datafaker/make.py index 17e0d3b8..67f42f66 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -5,7 +5,9 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple +from types import TracebackType +from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type +from typing_extensions import Self import pandas as pd import snsql @@ -13,15 +15,17 @@ from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider -from sqlalchemy import Engine, MetaData, UniqueConstraint, text +from sqlalchemy import CursorResult, Engine, MetaData, UniqueConstraint, text from sqlalchemy.dialects import postgresql -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table -from sqlalchemy.sql import sqltypes, type_api +from sqlalchemy.sql import Executable, sqltypes, type_api from datafaker import providers from datafaker.settings import get_settings from datafaker.utils import ( + MaybeAsyncEngine, create_db_engine, download_table, get_flag, @@ -90,6 +94,18 @@ def make_column_choices( ] +@dataclass +class _PrimaryConstraint: + """ + Describes a Uniqueness constraint for when multiple + columns in a table comprise the primary key. Not a + real constraint, but enough to write df.py. 
+ """ + + columns: list[Column] + name: str + + @dataclass class TableGeneratorInfo: """Contains the df.py content related to regular tables.""" @@ -100,7 +116,9 @@ class TableGeneratorInfo: column_choices: list[ColumnChoice] rows_per_pass: int row_gens: list[RowGeneratorInfo] = field(default_factory=list) - unique_constraints: list[UniqueConstraint] = field(default_factory=list) + unique_constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = field( + default_factory=list + ) @dataclass @@ -112,7 +130,7 @@ class StoryGeneratorInfo: num_stories_per_pass: int -def _render_value(v) -> str: +def _render_value(v: Any) -> str: if type(v) is list: return "[" + ", ".join(_render_value(x) for x in v) + "]" if type(v) is set: @@ -150,7 +168,7 @@ def _get_row_generator( ) -> tuple[list[RowGeneratorInfo], list[str]]: """Get the row generators information, for the given table.""" row_gen_info: list[RowGeneratorInfo] = [] - config: list[dict[str, Any]] = get_property(table_config, "row_generators", []) + config: list[Mapping[str, Any]] = get_property(table_config, "row_generators", []) columns_covered = [] for gen_conf in config: name: str = gen_conf["name"] @@ -220,14 +238,14 @@ def _get_default_generator(column: Column) -> RowGeneratorInfo: ( variable_names, generator_function, - generator_arguments, + generator_kwargs, ) = _get_provider_for_column(column) return RowGeneratorInfo( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, keyword_arguments=generator_arguments + function_name=generator_function, keyword_arguments=generator_kwargs ), ) @@ -238,13 +256,14 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: that limit its range to the permitted scale. 
""" column_type = column.type - if column_type.scale is None: + scale = getattr(column_type, "scale", None) + if scale is None: return ("generic.numeric.float_number", {}) return ( "generic.numeric.float_number", { - "start": 0, - "end": 10**column_type.scale - 1, + "start": "0", + "end": str(10**scale - 1), }, ) @@ -284,7 +303,7 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: @dataclass class GeneratorInfo: # Name or function to generate random objects of this type (not using summary data) - generator: str | Callable[[Column], str] + generator: str | Callable[[Column], tuple[str, dict[str, str]]] # SQL query that gets the data to supply as arguments to the generator # ({column} and {table} will be interpolated) summary_query: str | None = None @@ -298,13 +317,18 @@ class GeneratorInfo: choice: bool = False -def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: +def get_result_mappings( + info: GeneratorInfo, results: CursorResult +) -> dict[str, Any] | None: """ Gets a mapping from the results of a database query as a Python dictionary converted according to the GeneratorInfo provided. """ - kw = {} - for k, v in results.mappings().first().items(): + kw: dict[str, Any] = {} + mapping = results.mappings().first() + if mapping is None: + return kw + for k, v in mapping.items(): if v is None: return None conv_fn = info.arg_types.get(k, float) @@ -374,7 +398,7 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: def _get_generator_for_column( column_t: type, -) -> str | Callable[[type_api.TypeEngine], tuple[str, dict[str, str]]]: +) -> str | Callable[[Column], tuple[str, dict[str, str]]] | None: """ Gets a generator from a column type. 
@@ -386,7 +410,7 @@ def _get_generator_for_column( return None if info is None else info.generator -def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: +def _get_generator_and_arguments(column: Column) -> tuple[str | None, dict[str, str]]: """ Gets the generator and its arguments from the column type, returning a tuple of a string representing the generator callable and a dict of @@ -442,18 +466,6 @@ def _constraint_sort_key(constraint: UniqueConstraint) -> str: ) -class _PrimaryConstraint: - """ - Describes a Uniqueness constraint for when multiple - columns in a table comprise the primary key. Not a - real constraint, but enough to write df.py. - """ - - def __init__(self, *columns: Column, name: str): - self.name = name - self.columns = columns - - def _get_generator_for_table( table_config: Mapping[str, Any], table: Table, @@ -468,10 +480,12 @@ def _get_generator_for_table( key=_constraint_sort_key, ) primary_keys = [c for c in table.columns if c.primary_key] + constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = unique_constraints if 1 < len(primary_keys): - unique_constraints.append( - _PrimaryConstraint(*primary_keys, name=f"{table.name}_primary_key") + primary_constraint = _PrimaryConstraint( + columns=primary_keys, name=f"{table.name}_primary_key" ) + constraints = unique_constraints + [primary_constraint] column_choices = make_column_choices(table_config) if column_choices: nonnull_columns = { @@ -487,7 +501,7 @@ def _get_generator_for_table( nonnull_columns=nonnull_columns, column_choices=column_choices, rows_per_pass=get_property(table_config, "num_rows_per_pass", 1), - unique_constraints=unique_constraints, + unique_constraints=constraints, ) row_gen_info_data, columns_covered = _get_row_generator(table_config) @@ -525,7 +539,7 @@ def make_vocabulary_tables( overwrite_files: bool, compress: bool, table_names: set[str] | None = None, -): +) -> None: """ Extracts the data from the source database for each vocabulary 
table. @@ -660,8 +674,8 @@ def _generate_vocabulary_table( table: Table, engine: Engine, overwrite_files: bool = False, - compress=False, -): + compress: bool = False, +) -> None: """ Pulls data out of the source database to make a vocabulary YAML file """ @@ -712,33 +726,42 @@ def reflect_if(table_name: str, _: Any) -> bool: class DbConnection: - def __init__(self, engine): + def __init__(self, engine: MaybeAsyncEngine) -> None: + """ + Initialise an unopened database connection. + + Could be synchronous or asynchronous. + """ self._engine = engine + self._connection: Connection | AsyncConnection - async def __aenter__(self): + async def __aenter__(self) -> Self: if isinstance(self._engine, AsyncEngine): self._connection = await self._engine.connect() else: self._connection = self._engine.connect() return self - async def __aexit__(self, _type, _value, _tb): - if isinstance(self._engine, AsyncEngine): + async def __aexit__( + self, + _type: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[TracebackType], + ) -> None: + if isinstance(self._connection, AsyncConnection): await self._connection.close() - else: - self._connection.close() + self._connection.close() - async def execute_raw_query(self, query): - if isinstance(self._engine, AsyncEngine): + async def execute_raw_query(self, query: Executable) -> CursorResult: + if isinstance(self._connection, AsyncConnection): return await self._connection.execute(query) - else: - return self._connection.execute(query) + return self._connection.execute(query) - async def table_row_count(self, table_name: str): + async def table_row_count(self, table_name: str) -> int: with await self.execute_raw_query( text(f"SELECT COUNT(*) FROM {table_name}") ) as result: - return result.scalar_one() + return int(result.scalar_one()) async def execute_query(self, query_block: Mapping[str, Any]) -> Any: """Execute query in query_block.""" @@ -766,19 +789,19 @@ async def execute_query(self, query_block: 
Mapping[str, Any]) -> Any: return final_result -def fix_type(value): +def fix_type(value: Any) -> Any: if type(value) is decimal.Decimal: return float(value) return value -def fix_types(dics): +def fix_types(dics: list[dict]) -> list[dict]: return [{k: fix_type(v) for k, v in dic.items()} for dic in dics] async def make_src_stats( dsn: str, config: Mapping, metadata: MetaData, schema_name: Optional[str] = None -) -> dict[str, list[dict]]: +) -> dict[str, dict[str, Any]]: """Run the src-stats queries specified by the configuration. Query the src database with the queries in the src-stats block of the `config` @@ -801,7 +824,7 @@ async def make_src_stats( async def make_src_stats_connection( config: Mapping, db_conn: DbConnection, metadata: MetaData -): +) -> dict[str, dict[str, Any]]: date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/utils.py b/datafaker/utils.py index a94edf72..950f061d 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -2,12 +2,23 @@ import ast import gzip import importlib.util +import io import json import logging import sys from pathlib import Path from types import ModuleType -from typing import Any, Final, Mapping, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + Final, + Generator, + Iterable, + Mapping, + Optional, + TypeVar, + Union, +) import sqlalchemy import yaml @@ -39,6 +50,7 @@ ) T = TypeVar("T") +_K = TypeVar("_K") def read_config_file(path: str) -> dict: @@ -76,16 +88,18 @@ def import_file(file_path: str) -> ModuleType: ModuleType """ spec = importlib.util.spec_from_file_location("df", file_path) + if spec is None or spec.loader is None: + raise Exception(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module -def open_file(file_name): +def open_file(file_name: str | Path) -> io.BufferedWriter: return 
Path(file_name).open("wb") -def open_compressed_file(file_name): +def open_compressed_file(file_name: str | Path) -> gzip.GzipFile: return gzip.GzipFile(file_name, "wb") @@ -211,14 +225,14 @@ class StdoutHandler(logging.Handler): We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + def flush(self) -> None: self.acquire() try: sys.stdout.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: try: msg = self.format(record) sys.stdout.write(msg + "\n") @@ -235,14 +249,14 @@ class StderrHandler(logging.Handler): We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + def flush(self) -> None: self.acquire() try: sys.stderr.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: try: msg = self.format(record) sys.stderr.write(msg + "\n") @@ -279,17 +293,17 @@ def conf_logger(verbose: bool) -> None: logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) -def get_flag(maybe_dict, key): +def get_flag(maybe_dict: Any, key: Any) -> Any: """Returns maybe_dict[key] or False if that doesn't exist""" return type(maybe_dict) is dict and maybe_dict.get(key, False) -def get_property(maybe_dict, key, default): +def get_property(maybe_dict: Mapping[_K, Any], key: _K, default: T) -> T: """Returns maybe_dict[key] or default if that doesn't exist""" return maybe_dict.get(key, default) if type(maybe_dict) is dict else default -def fk_refers_to_ignored_table(fk: ForeignKey): +def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: """ Does this foreign key refer to a table that is configured as ignore in config.yaml """ @@ -300,7 +314,7 @@ def fk_refers_to_ignored_table(fk: ForeignKey): return False -def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint): +def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ Does this foreign key constraint refer to a table that 
is configured as ignore in config.yaml """ @@ -331,7 +345,8 @@ def table_is_private(config: Mapping, table_name: str) -> bool: if type(ts) is not dict: return False t = ts.get(table_name, {}) - return t.get("primary_private", False) + ret = t.get("primary_private", False) + return ret if type(ret) is bool else False def primary_private_fks(config: Mapping, table: Table) -> list[str]: @@ -364,7 +379,11 @@ def make_foreign_key_name(table_name: str, col_name: str) -> str: return f"{table_name}_{col_name}_fkey" -def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): +def remove_vocab_foreign_key_constraints( + metadata: MetaData, + config: Mapping[str, Any], + dst_engine: Connection | Engine, +) -> None: vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] @@ -392,7 +411,20 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): raise e -def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_engine): +def reinstate_vocab_foreign_key_constraints( + metadata: MetaData, + meta_dict: Mapping[str, Any], + config: Mapping[str, Any], + dst_engine: Connection | Engine, +) -> None: + """ + Put the removed foreign keys back into the destination database. + :param metadata: The SQLAlchemy metadata for the destination database. + :param meta_dict: The ``orm.yaml`` configuration that ``metadata`` was + created from. + :param config: The ``config.yaml`` data. + :param dst_engine: The connection to the destination database. + """ vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] @@ -419,7 +451,7 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng ) -def stream_yaml(yaml_file_handle): +def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any]: """ Stream a yaml list into an iterator. 
@@ -441,23 +473,24 @@ def stream_yaml(yaml_file_handle): buf += line -def topological_sort(input_nodes, get_dependencies_fn): +def topological_sort( + input_nodes: Iterable[T], get_dependencies_fn: Callable[[T], set[T]] +) -> tuple[list[T], list[list[T]]]: """ Topoligically sort input_nodes and find any cycles. - Returns a pair (sorted, cycles). + Returns a pair ``(sorted, cycles)``. - 'sorted' is a list of all the elements of input_nodes sorted + ``sorted`` is a list of all the elements of input_nodes sorted so that dependencies returned by get_dependencies_fn come after nodes that depend on them. Cycles are arbitrarily broken for this. - 'cycles' is a list of lists of dependency cycles. + ``cycles`` is a list of lists of dependency cycles. - arguments: - input_nodes: an iterator of nodes to sort. Duplicates + :param input_nodes: an iterator of nodes to sort. Duplicates are discarded. - get_dependencies_fn: a function that takes an input + :param get_dependencies_fn: a function that takes an input node and returns a list of its dependencies. Any dependencies not in the input_nodes list are ignored. """ @@ -503,6 +536,21 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta return [metadata.tables[tn] for tn in sorted] +def underline_error(e: SyntaxError) -> str: + """ + Make an underline for this error. + :return: string beginning ``\n`` then spaces then ``^^^^`` + underlining the error, or a null string if this was not possible. + """ + start = e.offset + if start is None: + return "" + end = e.end_offset + if end is None or end <= start: + end = start + 1 + return "\n" + " " * start + "^" * (end - start) + + def generators_require_stats(config: Mapping) -> bool: """ Returns true if any of the arguments for any of the generators reference SRC_STATS. 
@@ -536,12 +584,12 @@ def generators_require_stats(config: Mapping) -> bool: except SyntaxError as e: errors.append( ( - "Syntax error in argument %d of %s: %s\n%s\n%s", + "Syntax error in argument %d of %s: %s\n%s%s", n + 1, where, e.msg, arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), + underline_error(e), ) ) for k, arg in call.get("kwargs", {}).items(): @@ -557,12 +605,12 @@ def generators_require_stats(config: Mapping) -> bool: except SyntaxError as e: errors.append( ( - "Syntax error in argument %s of %s: %s\n%s\n%s", + "Syntax error in argument %s of %s: %s\n%s%s", k, where, e.msg, arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), + underline_error(e), ) ) for error in errors: From 3fffbadd926e2061ccc330081c4f90ff7f6ea586 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:29:34 +0100 Subject: [PATCH 12/44] pre-commit rewrites --- datafaker/base.py | 14 +++- datafaker/create.py | 4 +- datafaker/dump.py | 1 + datafaker/generators.py | 139 +++++++++++++++++++++----------- datafaker/interactive.py | 92 ++++++++++++++++----- datafaker/providers.py | 2 +- datafaker/serialize_metadata.py | 4 +- 7 files changed, 177 insertions(+), 79 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index acdc9b95..8270ccb9 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -31,7 +31,9 @@ def zipf_weights(size: int) -> list[float]: return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list[T], constants_at: dict[int, T]) -> Generator[T, None, None]: +def merge_with_constants( + xs: list[T], constants_at: dict[int, T] +) -> Generator[T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. 
:param constants_at: A map of indices to objects that must be placed at @@ -88,7 +90,7 @@ def choice(self, a: list[T]) -> T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a: list[T], n: int | None=None) -> T: + def zipf_choice(self, a: list[T], n: int | None = None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] @@ -226,7 +228,9 @@ def _check_generator_name(self, name: str) -> None: raise Exception("%s is not a permitted generator", name) def alternatives( - self, alternative_configs: list[dict[str, Any]], counts: list[dict[str, int]] | None + self, + alternative_configs: list[dict[str, Any]], + counts: list[dict[str, int]] | None, ) -> Any: """ A generator that picks between other generators. @@ -270,7 +274,9 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn: Callable[..., list[T]], params: dict, length: int) -> list[T]: + def truncated_string( + self, subgen_fn: Callable[..., list[T]], params: dict, length: int + ) -> list[T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: diff --git a/datafaker/create.py b/datafaker/create.py index e902ec3d..11b64a7f 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -276,9 +276,7 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values( - table_generator(dst_conn) - ) + stmt = insert(table).values(table_generator(dst_conn)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() diff --git a/datafaker/dump.py b/datafaker/dump.py index 4c309110..2ac187ff 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from _csv import Writer + def _make_csv_writer(file: io.TextIOBase) -> 
"Writer": """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py b/datafaker/generators.py index ff728582..aa0b20b2 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -12,13 +12,13 @@ from functools import lru_cache from itertools import chain, combinations from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union -from typing_extensions import Self import mimesis import mimesis.locales import sqlalchemy from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine +from typing_extensions import Self from datafaker.base import DistributionGenerator from datafaker.utils import logger @@ -123,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: """ Return a value representing how well the distribution fits the real source data. @@ -249,7 +249,9 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. 
""" @@ -353,7 +355,9 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -427,7 +431,7 @@ def nominal_kwargs(self) -> dict[str, Any]: def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit @@ -492,7 +496,9 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> Sequence[Generator]: + def make_singleton( + _cls, column: Column, engine: Engine, function_name: str + ) -> Sequence[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -592,7 +598,9 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +643,9 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +666,9 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +683,9 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +702,9 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +719,9 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -756,7 +774,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -827,7 +845,9 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -906,7 +926,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -974,11 +994,11 @@ class ChoiceGenerator(Generator): def __init__( self, table_name: str, - column_name : str, + column_name: str, values: list[Any], counts: list[int], - sample_count: int | None=None, - suppress_count: int=0, + sample_count: int | None = None, + suppress_count: int = 0, ) -> None: super().__init__() self.table_name = table_name @@ -1045,7 +1065,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit @@ -1080,6 +1100,7 @@ class UniformChoiceGenerator(ChoiceGenerator): """ A generator producing values, each roughly as frequently as each other. 
""" + def get_estimated_counts(self, counts: list[int]) -> list[int]: return list(uniform_distribution(sum(counts), len(counts))) @@ -1106,7 +1127,7 @@ def generate_data(self, count: int) -> list[Any]: class ValueGatherer: """ Gathers values from a query of values and counts. - + The query must return columns ``v`` for a value and ``f`` for the count of how many of those values there are. These values will be gathered into a number of properties: @@ -1120,15 +1141,12 @@ class ValueGatherer: :param suppress_count: value with a count of this or fewer will be excluded from the suppressed values. """ - def __init__(self, results: CursorResult, suppress_count: int=0) -> None: + + def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: values = [] # All values found counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = ( - [] - ) # All values found more than SUPPRESS_COUNT times + cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" + values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed cvs_not_suppressed: list[ dict[str, Any] @@ -1165,7 +1183,9 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1186,9 +1206,15 @@ def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Gene vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), - UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), - WeightedChoiceGenerator(table_name, column_name, 
vg.cvs, vg.counts), + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), ] results = connection.execute( text( @@ -1203,9 +1229,15 @@ def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Gene vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), - UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), - WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), ] generators += [ ZipfChoiceGenerator( @@ -1284,7 +1316,9 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1304,6 +1338,7 @@ class MultivariateNormalGenerator(Generator): """ Generator of multiple values drawn from a multivariate normal distribution. 
""" + def __init__( self, table_name: str, @@ -1350,7 +1385,7 @@ def generate_data(self, count: int) -> list[Any]: for _ in range(count) ] - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default @@ -1418,7 +1453,9 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1564,7 +1601,9 @@ def name(self) -> str: def function_name(self) -> str: return "dist_gen.alternatives" - def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) -> dict[str, Any]: + def _nominal_kwargs_with_combinations( + self, index: int, partition: RowPartition + ) -> dict[str, Any]: count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1621,7 +1660,9 @@ def custom_queries(self) -> dict[str, Any]: **partitions, } - def _actual_kwargs_with_combinations(self, partition: RowPartition) -> dict[str, Any]: + def _actual_kwargs_with_combinations( + self, partition: RowPartition + ) -> dict[str, Any]: count = sum(row["count"] for row in partition.covariates) if not partition.included_numeric and not partition.included_choice: return { @@ -1668,7 +1709,7 @@ def generate_data(self, count: int) -> list[Any]: kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default @@ -1737,12 +1778,14 @@ def __init__( class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - EMPTY_RESULT = [RowMapping( - 
parent=sqlalchemy.engine.result.ResultMetaData(), - processors=None, - key_to_index={"count": 0}, - data=(0,) - )] + EMPTY_RESULT = [ + RowMapping( + parent=sqlalchemy.engine.result.ResultMetaData(), + processors=None, + key_to_index={"count": 0}, + data=(0,), + ) + ] def function_name(self) -> str: return "grouped_multivariate_normal" @@ -1792,7 +1835,9 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 111a277a..3238a38e 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -9,18 +9,18 @@ from pathlib import Path from types import TracebackType from typing import Any, Callable, Iterable, Optional, Type, cast -from typing_extensions import Self import sqlalchemy from prettytable import PrettyTable from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text +from typing_extensions import Self from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( T, create_db_engine, - get_sync_engine, fk_refers_to_ignored_table, + get_sync_engine, logger, primary_private_fks, table_is_private, @@ -113,7 +113,11 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non ... 
def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ): super().__init__() self.config: MutableMapping[str, Any] = config @@ -192,7 +196,7 @@ def set_table_index(self, index: int) -> bool: return True return False - def next_table(self, report: str="No more tables") -> bool: + def next_table(self, report: str = "No more tables") -> bool: """ Move to the next table :return: True if there is another table to move to. @@ -203,7 +207,7 @@ def next_table(self, report: str="No more tables") -> bool: return True def table_name(self) -> str: - """ Get the name of the current table. """ + """Get the name of the current table.""" return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: @@ -351,7 +355,9 @@ def do_peek(self, arg: str) -> None: rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] self.print_table(list(result.keys()), rows) - def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_peek( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: if len(self._table_entries) <= self.table_index: return [] return [ @@ -400,7 +406,11 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -561,7 +571,9 @@ def do_next(self, arg: str) -> None: return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def 
complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] @@ -645,7 +657,9 @@ def do_data(self, arg: str) -> None: number = 48 self.print_column_data(column, number, min_length) - def complete_data(self, text: str, line: str, begidx: int, _endidx: int) -> list[str]: + def complete_data( + self, text: str, line: str, begidx: int, _endidx: int + ) -> list[str]: previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] @@ -757,7 +771,9 @@ def find_missingness_query( return (query, src_stat.get("comment", "")) return None - def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: + def make_table_entry( + self, name: str, table: Mapping + ) -> MissingnessCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -798,7 +814,11 @@ def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntr ) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping, ): """ Initialise a MissingnessCmd. 
@@ -814,7 +834,9 @@ def __init__( def table_entries(self) -> list[MissingnessCmdTableEntry]: return cast(list[MissingnessCmdTableEntry], self._table_entries) - def find_entry_by_table_name(self, table_name: str) -> MissingnessCmdTableEntry | None: + def find_entry_by_table_name( + self, table_name: str + ) -> MissingnessCmdTableEntry | None: entry = super().find_entry_by_table_name(table_name) if entry is None: return None @@ -927,7 +949,9 @@ def do_next(self, arg: str) -> None: return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] @@ -1000,7 +1024,10 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1012,6 +1039,7 @@ class GeneratorInfo: """ A generator and the columns it assigns to. """ + columns: list[str] gen: Generator | None @@ -1023,6 +1051,7 @@ class GeneratorCmdTableEntry(TableEntry): Includes the original setting and the currently configured generators. """ + old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] @@ -1031,6 +1060,7 @@ class GeneratorCmd(DbCmd): """ Interactive command shell for setting generators. """ + intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the current column, then command 'compare' to see how these perform @@ -1062,7 +1092,9 @@ class GeneratorCmd(DbCmd): r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
) - def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTableEntry | None: + def make_table_entry( + self, table_name: str, table: Mapping + ) -> GeneratorCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -1130,7 +1162,11 @@ def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTable ) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> None: """ Initialise a GeneratorCmd @@ -1148,7 +1184,9 @@ def __init__( def table_entries(self) -> list[GeneratorCmdTableEntry]: return cast(list[GeneratorCmdTableEntry], self._table_entries) - def find_entry_by_table_name(self, table_name: str) -> GeneratorCmdTableEntry | None: + def find_entry_by_table_name( + self, table_name: str + ) -> GeneratorCmdTableEntry | None: entry = super().find_entry_by_table_name(table_name) if entry is None: return None @@ -1465,7 +1503,9 @@ def _go_next(self) -> None: self.generator_index = next_gi self.set_prompt() - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1628,7 +1668,9 @@ def _print_custom_queries(self, gen: Generator) -> None: cq_key2args[cq_key], ) - def _get_custom_queries_from(self, out: dict[str, Any], nominal: Any, actual: Any) -> None: + def _get_custom_queries_from( + self, out: dict[str, Any], nominal: Any, actual: Any + ) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) if src_stat_groups: @@ -1689,7 +1731,9 @@ def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; 
providing the following values: {1}", select_q, vals) - def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> list[list[str]]: + def _get_column_data( + self, count: int, to_str: Callable[[Any], str] = repr + ) -> list[list[str]]: columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) @@ -1852,7 +1896,9 @@ def do_merge(self, arg: str) -> None: table_entry.new_generators = new_new_generators self.set_prompt() - def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_merge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1905,7 +1951,9 @@ def do_unmerge(self, arg: str) -> None: ) self.set_prompt() - def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_unmerge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: diff --git a/datafaker/providers.py b/datafaker/providers.py index 9639f1b5..03e6cbe2 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -29,7 +29,7 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed: int | None=None, **kwargs: Any) -> None: + def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 936eb9f3..d407d494 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,10 +1,10 @@ +import typing from typing import Callable, Protocol import parsy from sqlalchemy import Column, Dialect, Engine, 
ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.sql import schema, sqltypes -import typing from datafaker.utils import make_foreign_key_name @@ -268,7 +268,7 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool: return bool(tables_dict[fk_bits[0]].get("ignore", False)) -def dict_to_metadata(obj: dict, config_for_output: dict | None=None) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. From 55acf142c01d7bf2c0002aba6e84065ac3514c54 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:43:30 +0100 Subject: [PATCH 13/44] test_dump is mypy clean --- tests/test_dump.py | 5 +++-- tests/utils.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_dump.py b/tests/test_dump.py index 2d5ed268..7033e18f 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -1,4 +1,5 @@ """Tests for the base module.""" +import io from unittest.mock import MagicMock, call, patch from sqlalchemy.schema import MetaData @@ -17,9 +18,9 @@ class DumpTests(RequiresDBTestCase): @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: """Test dump-data.""" - TEST_OUTPUT_FILE = "test_output_file_object" + TEST_OUTPUT_FILE = io.StringIO() metadata = MetaData() - metadata.reflect(self.engine) + metadata.reflect(self.sync_engine) dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) make_csv_writer.assert_has_calls( diff --git a/tests/utils.py b/tests/utils.py index a6eb5931..d74851e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import 
create_db_engine, import_file, sorted_non_vocabulary_tables +from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables class SysExit(Exception): @@ -115,11 +115,11 @@ class RequiresDBTestCase(DatafakerTestCase): reflected from that engine. """ - schema_name = None + schema_name: str | None = None use_asyncio = False examples_dir = "tests/examples" - dump_file_path = None - database_name = None + dump_file_path: str | None = None + database_name: str | None = None Postgresql = None @classmethod @@ -140,6 +140,7 @@ def setUp(self) -> None: schema_name=self.schema_name, use_asyncio=self.use_asyncio, ) + self.sync_engine = get_sync_engine(self.engine) self.metadata = MetaData() self.metadata.reflect(self.engine) From c3709b8f6cd09700ac42fe0d1eead918a086d7bb Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 19:10:12 +0100 Subject: [PATCH 14/44] Some mypy cleaning of tests directory --- datafaker/utils.py | 2 +- tests/test_base.py | 2 +- tests/test_create.py | 10 ++++----- tests/test_functional.py | 5 +++-- tests/test_make.py | 2 +- tests/test_providers.py | 4 ++-- tests/test_unique_generator.py | 9 ++++---- tests/test_utils.py | 14 ++++++------ tests/utils.py | 41 +++++++++++++++++++--------------- 9 files changed, 47 insertions(+), 42 deletions(-) diff --git a/datafaker/utils.py b/datafaker/utils.py index 950f061d..b34664c2 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -451,7 +451,7 @@ def reinstate_vocab_foreign_key_constraints( ) -def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any]: +def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any, None, None]: """ Stream a yaml list into an iterator. 
diff --git a/tests/test_base.py b/tests/test_base.py index 3f1e8cd4..411f1c09 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -46,7 +46,7 @@ def test_load(self) -> None: """Test the load method.""" vocab_gen = FileUploader(BaseTable.__table__) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: vocab_gen.load(conn) statement = select(BaseTable) rows = list(conn.execute(statement)) diff --git a/tests/test_create.py b/tests/test_create.py index 333c01a2..ca8c8f83 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -4,7 +4,7 @@ import random from collections import Counter from pathlib import Path -from typing import Any, Generator, Tuple +from typing import Any, Generator, Mapping, Tuple from unittest.mock import MagicMock, call, patch from sqlalchemy import Connection, select @@ -39,11 +39,11 @@ def test_create_vocab(self) -> None: }, } self.set_configuration(config) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) self.remove_data(config) remove_db_vocab(self.metadata, meta_dict, config) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = list(conn.execute(stmt).mappings().fetchall()) self.assertEqual(len(rows), 3) @@ -60,9 +60,9 @@ def test_create_vocab(self) -> None: def test_make_table_generators(self) -> None: """Test that we can handle column defaults in stories.""" random.seed(56) - config = {} + config: Mapping[str, Any] = {} self.generate_data(config, num_passes=2) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["string"]) rows = list(conn.execute(stmt).mappings().fetchall()) a = rows[0] diff --git a/tests/test_functional.py b/tests/test_functional.py index 418b4f96..05c3b890 
100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,9 +2,10 @@ import os import shutil from pathlib import Path +from typing import Any, Mapping from sqlalchemy import create_engine, inspect -from typer.testing import CliRunner +from typer.testing import CliRunner, Result from datafaker.main import app from tests.utils import RequiresDBTestCase @@ -431,7 +432,7 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args, expected_error: str = None, env={}): + def invoke(self, *args: Any, expected_error: str | None=None, env: Mapping[str, str]={}) -> Result: res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_make.py b/tests/test_make.py index f43588ac..b522778f 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -45,7 +45,7 @@ def test_make_table_generators(self) -> None: }, } self.generate_data(config, num_passes=3) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = conn.execute(stmt).mappings().fetchall() for row in rows: diff --git a/tests/test_providers.py b/tests/test_providers.py index aedb693a..b5437833 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -49,7 +49,7 @@ def test_column_value_present(self) -> None: """Test the key method.""" # pylint: disable=invalid-name - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = insert(Person).values(sex="M") conn.execute(stmt) @@ -61,7 +61,7 @@ def test_column_value_present(self) -> None: def test_column_value_missing(self) -> None: """Test the generator when there are no values in the source table.""" - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: provider: providers.ColumnValueProvider = providers.ColumnValueProvider() generated_value: Any = provider.column_value(connection, Person, "sex") 
diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 81e9eeac..c64e281c 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -8,7 +8,6 @@ Integer, Text, UniqueConstraint, - create_engine, insert, ) from sqlalchemy.ext.declarative import declarative_base @@ -55,7 +54,7 @@ def test_unique_generator_empty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: # Find a couple of different values that could be inserted, then try to do # one duplicate. test_ab1 = [True, False] @@ -83,7 +82,7 @@ def test_unique_generator_nonempty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_ab1 = [True, False] test_ab2 = [False, False] string1 = "String 1" @@ -109,7 +108,7 @@ def test_unique_generator_multivalue_generator(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_val1 = (True, False, "String 1") test_val2 = (True, False, "String 2") # Conflicts on (a, b) test_val3 = (False, False, "String 1") # Conflicts on c @@ -143,7 +142,7 @@ def test_unique_generator_max_tries(self) -> None: test_val = (True, False, "String 1") mock_generator.return_value = test_val - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(uniq_ab(conn, ["a", "b", "c"], mock_generator), test_val) self.assertRaises( RuntimeError, uniq_ab, conn, ["a", "b", "c"], mock_generator diff --git a/tests/test_utils.py b/tests/test_utils.py index 2640a9e1..b34e9227 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -81,12 +81,12 @@ def 
test_download_table(self) -> None: """Test the download_table function.""" # pylint: disable=protected-access - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(MyTable).values({"id": 1})) conn.commit() download_table( - MyTable.__table__, self.engine, self.mytable_file_path, compress=False + MyTable.__table__, self.sync_engine, self.mytable_file_path, compress=False ) # The .strip() gets rid of any possible empty lines at the end of the file. @@ -287,7 +287,7 @@ def test_generators_require_stats(self) -> None: ) @patch("datafaker.utils.logger") - def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): + def test_testing_generators_finds_syntax_errors(self, logger: MagicMock) -> None: generators_require_stats( { "story_generators": [ @@ -309,20 +309,20 @@ def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): logger.error.assert_has_calls( [ call( - "Syntax error in argument %s of %s: %s\n%s\n%s", + "Syntax error in argument %s of %s: %s\n%s%s", "b", "story_generators[0]", "unterminated string literal (detected at line 1)", "'unclosed", - " ^", + "\n ^", ), call( - "Syntax error in argument %d of %s: %s\n%s\n%s", + "Syntax error in argument %d of %s: %s\n%s%s", 1, "tables.things.row_generators[0]", "invalid syntax", "1 2", - " ^", + "\n ^", ), ] ) diff --git a/tests/utils.py b/tests/utils.py index d74851e8..a75b4f04 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from pathlib import Path from subprocess import run from tempfile import mkstemp -from typing import Any +from typing import Any, Mapping from unittest import TestCase, skipUnless import testing.postgresql @@ -19,7 +19,7 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import create_db_engine, get_sync_engine, import_file, 
sorted_non_vocabulary_tables +from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables, T class SysExit(Exception): @@ -47,7 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.maxDiff = None # pylint: disable=invalid-name super().__init__(*args, **kwargs) - def setUp(self): + def setUp(self) -> None: settings.get_settings.cache_clear() def assertReturnCode( # pylint: disable=invalid-name @@ -74,7 +74,7 @@ def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-nam return self.fail("".join(traceback.format_exception(result.exception))) - def assertSubset(self, set1, set2, msg=None): + def assertSubset(self, set1: set[T], set2: set[T], msg: str | None=None) -> None: """Assert a set is a (non-strict) subset. Args: @@ -117,21 +117,23 @@ class RequiresDBTestCase(DatafakerTestCase): schema_name: str | None = None use_asyncio = False - examples_dir = "tests/examples" + examples_dir = Path("tests/examples") dump_file_path: str | None = None database_name: str | None = None Postgresql = None @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.Postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) @classmethod - def tearDownClass(cls): - cls.Postgresql.clear_cache() + def tearDownClass(cls) -> None: + if cls.Postgresql is not None: + cls.Postgresql.clear_cache() def setUp(self) -> None: super().setUp() + assert self.Postgresql is not None self.postgresql = self.Postgresql() if self.dump_file_path is not None: self.run_psql(Path(self.examples_dir) / Path(self.dump_file_path)) @@ -142,17 +144,20 @@ def setUp(self) -> None: ) self.sync_engine = get_sync_engine(self.engine) self.metadata = MetaData() - self.metadata.reflect(self.engine) + self.metadata.reflect(self.sync_engine) def tearDown(self) -> None: self.postgresql.stop() super().tearDown() @property - def dsn(self): + def dsn(self) -> str: if self.database_name: - return 
self.postgresql.url(database=self.database_name) - return self.postgresql.url() + url = self.postgresql.url(database=self.database_name) + else: + url = self.postgresql.url() + assert type(url) is str + return url def run_psql(self, dump_file: Path) -> None: """Run psql and pass dump_file_name as the --file option.""" @@ -187,7 +192,7 @@ def setUp(self) -> None: with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh: orm_fh.write(make_tables_file(self.dsn, self.schema_name, {})) - def set_configuration(self, config) -> None: + def set_configuration(self, config: Mapping[str, Any]) -> None: """ Accepts a configuration file, writes it out. """ @@ -195,7 +200,7 @@ def set_configuration(self, config) -> None: with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) - def get_src_stats(self, config) -> dict[str, any]: + def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: """ Runs `make-stats` producing `src-stats.yaml` :return: Python dictionary representation of the contents of the src-stats file @@ -212,7 +217,7 @@ def get_src_stats(self, config) -> dict[str, any]: stats_fh.write(yaml.dump(src_stats)) return src_stats - def create_generators(self, config) -> None: + def create_generators(self, config: Mapping[str, Any]) -> None: """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" datafaker_content = make_table_generators( self.metadata, @@ -225,12 +230,12 @@ def create_generators(self, config) -> None: with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh: datafaker_fh.write(datafaker_content) - def remove_data(self, config): + def remove_data(self, config: Mapping[str, Any]) -> None: """Remove source data from the DB.""" # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) - def create_data(self, config, num_passes=1): + def create_data(self, config: 
Mapping[str, Any], num_passes: int=1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) @@ -245,7 +250,7 @@ def create_data(self, config, num_passes=1): self.schema_name, ) - def generate_data(self, config, num_passes=1): + def generate_data(self, config: Mapping[str, Any], num_passes: int=1) -> Mapping[str, Any]: """ Replaces the DB's source data with generated data. :return: A Python dictionary representation of the src-stats.yaml file, for what it's worth. From 113c4d203c60270e758f67d15a42991dfbfa7acf Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 8 Oct 2025 12:10:27 +0100 Subject: [PATCH 15/44] Much more cleaning. mypy clean --- .github/workflows/pre-commit.yml | 2 +- datafaker/interactive.py | 2 +- datafaker/make.py | 2 +- tests/test_create.py | 4 +- tests/test_functional.py | 4 +- tests/test_interactive.py | 253 ++++++++++++++++++------------- tests/test_remove.py | 47 +++--- tests/test_unique_generator.py | 9 +- tests/utils.py | 34 +++-- 9 files changed, 205 insertions(+), 152 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3e9d2137..b07de4d1 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -9,7 +9,7 @@ on: env: # This should be the default but we'll be explicit PRE_COMMIT_HOME: ~/.caches/pre-commit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: the_job: runs-on: ubuntu-latest diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 3238a38e..47f918cf 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1969,7 +1969,7 @@ def update_config_generators( src_dsn: str, src_schema: str | None, metadata: MetaData, - config: MutableMapping[str, Any], + config: Mapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ diff --git a/datafaker/make.py b/datafaker/make.py index 67f42f66..ac7cdfce 100644 --- a/datafaker/make.py +++ 
b/datafaker/make.py @@ -7,7 +7,6 @@ from pathlib import Path from types import TracebackType from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type -from typing_extensions import Self import pandas as pd import snsql @@ -21,6 +20,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table from sqlalchemy.sql import Executable, sqltypes, type_api +from typing_extensions import Self from datafaker import providers from datafaker.settings import get_settings diff --git a/tests/test_create.py b/tests/test_create.py index ca8c8f83..b175f070 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -39,7 +39,9 @@ def test_create_vocab(self) -> None: }, } self.set_configuration(config) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) self.remove_data(config) remove_db_vocab(self.metadata, meta_dict, config) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) diff --git a/tests/test_functional.py b/tests/test_functional.py index 05c3b890..394691b3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -432,7 +432,9 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args: Any, expected_error: str | None=None, env: Mapping[str, str]={}) -> Result: + def invoke( + self, *args: Any, expected_error: str | None = None, env: Mapping[str, str] = {} + ) -> Result: res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 284e04d0..e1de3b3a 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -3,6 +3,7 @@ import random import re from dataclasses import dataclass +from typing import Any, Iterable, Mapping, MutableMapping from unittest.mock import MagicMock, 
Mock, patch from sqlalchemy import insert, select @@ -19,31 +20,39 @@ class TestDbCmdMixin(DbCmd): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a TestDbCmdMixin""" super().__init__(*args, **kwargs) self.reset() - def reset(self): - self.messages: list[tuple[str, list, dict[str, any]]] = [] + def reset(self) -> None: + """Reset all the debug messages collected so far.""" + self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] self.headings: list[str] = [] self.rows: list[list[str]] = [] - self.column_items: list[str] = [] - self.columns: dict[str, list[str]] = {} + self.column_items: list[list[str]] = [] + self.columns: dict[str, list[Any]] = {} - def print(self, text: str, *args, **kwargs): + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Capture the printed message.""" self.messages.append((text, args, kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]): + def print_table(self, headings: list[str], rows: list[list[str]]) -> None: + """Capture the printed table.""" self.headings = headings self.rows = rows - def print_table_by_columns(self, columns: dict[str, list[str]]): + def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: + """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str]): - self.column_items.append(items) + def columnize(self, items: list[str] | None, displaywidth: int = 80) -> None: + """Capture the printed table.""" + if items is not None: + self.column_items.append(items) def ask_save(self) -> str: + """Quitting always works without needing to ask the user.""" return "yes" @@ -54,7 +63,7 @@ class TestTableCmd(TableCmd, TestDbCmdMixin): class ConfigureTablesTests(RequiresDBTestCase): """Testing configure-tables.""" - def _get_cmd(self, config) -> TestTableCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: return TestTableCmd(self.dsn, 
self.schema_name, self.metadata, config) @@ -67,7 +76,7 @@ class ConfigureTablesSrcTests(ConfigureTablesTests): def test_table_name_prompts(self) -> None: """Test that the prompts follow the names of the tables.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as tc: table_names = list(self.metadata.tables.keys()) for t in table_names: @@ -95,7 +104,7 @@ def test_table_name_prompts(self) -> None: def test_column_display(self) -> None: """Test that we can see the names of the columns.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as tc: tc.do_next("unique_constraint_test") tc.do_columns("") @@ -211,7 +220,7 @@ def test_configure_tables(self) -> None: def test_print_data(self) -> None: """Test that we can print random rows from the table and random data from columns.""" person_table = self.metadata.tables["person"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: person_rows = conn.execute(select(person_table)).mappings().fetchall() person_data = {row["person_id"]: row for row in person_rows} name_set = {row["name"] for row in person_rows} @@ -255,7 +264,7 @@ def test_print_data(self) -> None: set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) ) - def test_list_tables(self): + def test_list_tables(self) -> None: """Test that we can list the tables""" config = { "tables": { @@ -308,7 +317,10 @@ class ConfigureTablesInstrumentsTests(ConfigureTablesTests): database_name = "instrument" schema_name = "public" - def test_sanity_checks_both(self): + def test_sanity_checks_both(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ config = { "tables": { "model": { @@ -349,7 +361,10 @@ def test_sanity_checks_both(self): ), ) - def test_sanity_checks_warnings_only(self): + def test_sanity_checks_warnings_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
+ """ config = { "tables": { "model": { @@ -388,7 +403,10 @@ def test_sanity_checks_warnings_only(self): ), ) - def test_sanity_checks_errors_only(self): + def test_sanity_checks_errors_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ config = { "tables": { "model": { @@ -431,7 +449,7 @@ def test_sanity_checks_errors_only(self): class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): """GeneratorCmd but mocked""" - def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: + def get_proposals(self) -> dict[str, tuple[int, str, list[str]]]: """ Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples]) """ @@ -449,10 +467,11 @@ class ConfigureGeneratorsTests(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the command we are using for this test case.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_null_configuration(self): + def test_null_configuration(self) -> None: """Test that the tables having null configuration does not break.""" config = { "tables": None, @@ -466,7 +485,7 @@ def test_null_configuration(self): gc.do_quit("") self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - def test_null_table_configuration(self): + def test_null_table_configuration(self) -> None: """Test that a table having null configuration does not break.""" config = { "tables": { @@ -483,7 +502,7 @@ def test_null_table_configuration(self): def test_prompts(self) -> None: """Test that the prompts follow the names of the columns and assigned generators.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as gc: for table_name, table_meta in self.metadata.tables.items(): for column_name, column_meta in table_meta.columns.items(): @@ -521,7 +540,7 @@ def test_prompts(self) -> None: 
) gc.reset() - def test_set_generator_mimesis(self): + def test_set_generator_mimesis(self) -> None: """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: TABLE = "model" @@ -538,7 +557,7 @@ def test_set_generator_mimesis(self): {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, ) - def test_set_generator_distribution(self): + def test_set_generator_distribution(self) -> None: """Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -571,7 +590,7 @@ def test_set_generator_distribution(self): f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", ) - def test_set_generator_distribution_directly(self): + def test_set_generator_distribution_directly(self) -> None: """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -592,7 +611,7 @@ def test_set_generator_distribution_directly(self): f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", ) - def test_set_generator_choice(self): + def test_set_generator_choice(self) -> None: """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -626,7 +645,7 @@ def test_set_generator_choice(self): f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", ) - def test_weighted_choice_generator_generates_choices(self): + def test_weighted_choice_generator_generates_choices(self) -> None: """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -643,7 +662,7 @@ def test_weighted_choice_generator_generates_choices(self): self.assertIn(col_heading, gc.columns) self.assertSubset(set(gc.columns[col_heading]), VALUES) - def test_merge_columns(self): + def test_merge_columns(self) -> None: """Test that we can merge 
columns and set a multivariate generator""" TABLE = "string" COLUMN_1 = "frequency" @@ -683,7 +702,7 @@ def test_merge_columns(self): self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) - def test_unmerge_columns(self): + def test_unmerge_columns(self) -> None: """Test that we can unmerge columns and generators are removed""" TABLE = "string" COLUMN_1 = "frequency" @@ -722,7 +741,7 @@ def test_unmerge_columns(self): self.assertEqual(row_gen["name"], REMAINING_GEN) self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) - def test_old_generators_remain(self): + def test_old_generators_remain(self) -> None: """Test that we can set one generator and keep an old one.""" config = { "tables": { @@ -782,7 +801,7 @@ def test_old_generators_remain(self): "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", ) - def test_aggregate_queries_merge(self): + def test_aggregate_queries_merge(self) -> None: """ Test that we can set a generator that requires select aggregate clauses and keep an old one, resulting in a merged query. 
@@ -817,7 +836,7 @@ def test_aggregate_queries_merge(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - row_gens: list[dict[str, any]] = gc.config["tables"]["string"][ + row_gens: list[dict[str, Any]] = gc.config["tables"]["string"][ "row_generators" ] self.assertEqual(len(row_gens), 2) @@ -850,9 +869,9 @@ def test_aggregate_queries_merge(self): select_match = re.match( r"SELECT (.*) FROM string", gc.config["src-stats"][0]["query"] ) - self.assertIsNotNone( - select_match, "src_stats[0].query is not an aggregate select" - ) + assert ( + select_match is not None + ), "src_stats[0].query is not an aggregate select" self.assertSetEqual( set(select_match.group(1).split(", ")), { @@ -863,7 +882,7 @@ def test_aggregate_queries_merge(self): }, ) - def test_next_completion(self): + def test_next_completion(self) -> None: """Test tab completion for the next command.""" with self._get_cmd({}) as gc: self.assertSetEqual( @@ -887,7 +906,7 @@ def test_next_completion(self): ) self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) - def test_compare_reports_privacy(self): + def test_compare_reports_privacy(self) -> None: """ Test that compare reports whether the current table is primary private, secondary private or not private. @@ -917,11 +936,11 @@ def test_compare_reports_privacy(self): self.assertEqual(text, gc.SECONDARY_PRIVATE_TEXT) self.assertSequenceEqual(args, [["model"]]) - def test_existing_configuration_remains(self): + def test_existing_configuration_remains(self) -> None: """ Test setting a generator does not remove other information. 
""" - config = { + config: MutableMapping[str, Any] = { "tables": { "string": { "primary_private": True, @@ -946,7 +965,7 @@ def test_existing_configuration_remains(self): self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) self.assertTrue(gc.config["tables"]["string"]["primary_private"]) - def test_empty_tables_are_not_configured(self): + def test_empty_tables_are_not_configured(self) -> None: """Test that tables marked as empty are not configured.""" config = { "tables": { @@ -969,10 +988,10 @@ class GeneratorsOutputTests(GeneratesDBTestCase): database_name = "numbers" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_create_with_sampled_choice(self): + def test_create_with_sampled_choice(self) -> None: """Test that suppression works for choice and zipf_choice.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1013,7 +1032,7 @@ def test_create_with_sampled_choice(self): gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1028,7 +1047,7 @@ def test_create_with_sampled_choice(self): self.assertSetEqual(twos, {2, 3}) self.assertSetEqual(threes, {1, 2, 3, 4, 5}) - def test_create_with_choice(self): + def test_create_with_choice(self) -> None: """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1044,7 +1063,7 @@ def test_create_with_choice(self): gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = 
select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1056,7 +1075,7 @@ def test_create_with_choice(self): self.assertSetEqual(ones, {1, 2, 3, 4, 5}) self.assertSetEqual(twos, {1, 2, 3, 4, 5}) - def test_create_with_weighted_choice(self): + def test_create_with_weighted_choice(self) -> None: """Smoke test weighted choice.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1077,7 +1096,8 @@ def test_create_with_weighted_choice(self): f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" ) self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 4}) + col_set: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set, {1, 4}) gc.do_set(str(prop[0])) gc.do_next("number_table.two") gc.reset() @@ -1094,7 +1114,8 @@ def test_create_with_weighted_choice(self): gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. dist_gen.weighted_choice" self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) + col_set2: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set2, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_next("number_table.three") gc.reset() @@ -1110,11 +1131,12 @@ def test_create_with_weighted_choice(self): gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. 
dist_gen.weighted_choice [sampled]" self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) + col_set3: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set3, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1141,62 +1163,60 @@ class ConfigureMissingnessTests(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestMissingnessCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + """We are using configure-missingness.""" return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - def test_set_missingness_to_sampled(self): + def test_set_missingness_to_sampled(self) -> None: """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: TABLE = "signature_model" mc.do_next(TABLE) mc.do_counts("") self.assertListEqual( - mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})] + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] ) - self.assertListEqual(mc.rows, [["player_id", 3], ["based_on", 2]]) + # Check the counts of NULLs in each column + self.assertListEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) mc.do_sampled("") mc.do_quit("") - self.assertDictEqual( - mc.config, - { - "tables": { - TABLE: { - "missingness_generators": [ - { - "columns": ["player_id", "based_on"], - "kwargs": { - "patterns": 'SRC_STATS["missing_auto__signature_model__0"]' - }, - "name": "column_presence.sampled", - } - ] - } - }, - "src-stats": [ - { - "name": "missing_auto__signature_model__0", - "query": ( - "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null, based_on 
IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" - ), - } - ], - }, + self.assertListEqual( + mc.config["tables"][TABLE]["missingness_generators"], + [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' + }, + "name": "column_presence.sampled", + } + ], + ) + self.assertEqual( + mc.config["src-stats"][0]["name"], + "missing_auto__signature_model__0", + ) + self.assertEqual( + mc.config["src-stats"][0]["query"], + ( + "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + ), ) -class ConfigureMissingnessTests(GeneratesDBTestCase): +class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): """Testing configure-missing with generation.""" dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestMissingnessCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - def test_create_with_missingness(self): + def test_create_with_missingness(self) -> None: """Test that we can sample real missingness and reproduce it.""" random.seed(45) # Configure the missingness @@ -1208,7 +1228,7 @@ def test_create_with_missingness(self): config = mc.config self.generate_data(config, num_passes=100) # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).mappings().fetchall() patterns: set[int] = set() @@ -1227,10 +1247,11 @@ class 
GeneratorTests(GeneratesDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """We are using configure-generators.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_set_null(self): + def test_set_null(self) -> None: """Test that we can sample real missingness and reproduce it.""" with self._get_cmd({}) as gc: gc.do_next("string.position") @@ -1256,8 +1277,13 @@ def test_set_null(self): config = gc.config self.generate_data(config, num_passes=3) # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: - stmt = select(self.metadata.tables["string"].c["position", "frequency"]) + with self.sync_engine.connect() as conn: + # select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer + # but mypy doesn't like it + stmt = select( + self.metadata.tables["string"].c["position"], + self.metadata.tables["string"].c["frequency"], + ) rows = conn.execute(stmt).fetchall() count = 0 for row in rows: @@ -1265,7 +1291,12 @@ def test_set_null(self): self.assertEqual(row.position, 0) self.assertEqual(row.frequency, 0.0) self.assertEqual(count, 3) - stmt = select(self.metadata.tables["signature_model"].c["name", "based_on"]) + # select(self.metadata.tables["signature_model"].c["name", "based_on"]) would be nicer + # but mypy doesn't like it + stmt = select( + self.metadata.tables["signature_model"].c["name"], + self.metadata.tables["signature_model"].c["based_on"], + ) rows = conn.execute(stmt).fetchall() count = 0 for row in rows: @@ -1274,7 +1305,7 @@ def test_set_null(self): self.assertIsNone(row.based_on) self.assertEqual(count, 3) - def test_dist_gen_sampled_produces_ordered_src_stats(self): + def test_dist_gen_sampled_produces_ordered_src_stats(self) -> None: """Tests that choosing a sampled choice generator produces ordered src stats""" 
with self._get_cmd({}) as gc: gc.do_next("signature_model.player_id") @@ -1294,7 +1325,11 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self): ] self.assertListEqual(based_ons, [1, 3, 2]) - def assertAreTruncatedTo(self, xs, length): + def assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: + """ + Check that none of the strings are longer than ``length`` (after + removing surrounding quotes). + """ maxlen = 0 for x in xs: newlen = len(x.strip("'\"")) @@ -1302,7 +1337,7 @@ def assertAreTruncatedTo(self, xs, length): maxlen = max(maxlen, newlen) self.assertEqual(maxlen, length) - def test_varchar_ns_are_truncated(self): + def test_varchar_ns_are_truncated(self) -> None: """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" GENERATOR = "generic.text.quote" TABLE = "signature_model" @@ -1325,7 +1360,7 @@ def test_varchar_ns_are_truncated(self): gc.do_quit("") config = gc.config self.generate_data(config, num_passes=15) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[TABLE].c[COLUMN]) rows = conn.execute(stmt).scalars().fetchall() self.assertAreTruncatedTo(rows, 20) @@ -1359,7 +1394,7 @@ class Correlation(Stat): y2: float = 0 xy: float = 0 - def add(self, x: float, y: float) -> None: + def add2(self, x: float, y: float) -> None: self.n += 1 self.x += x self.x2 += x * x @@ -1386,14 +1421,16 @@ class NullPartitionedTests(GeneratesDBTestCase): schema_name = "public" def setUp(self) -> None: + """Set up the test with specific sample and suppress counts.""" super().setUp() NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the configure-generators object as our command.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def 
test_create_with_null_partitioned_grouped_multivariate(self): + def test_create_with_null_partitioned_grouped_multivariate(self) -> None: """Test EAV for all columns.""" table_name = "measurement" generate_count = 800 @@ -1422,7 +1459,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.remove_data(gc.config) # let's add a vocab table without messing around with files table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(table).values({"id": 1, "name": "agreement"})) conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) conn.execute(insert(table).values({"id": 3, "name": "velocity"})) @@ -1430,7 +1467,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): conn.execute(insert(table).values({"id": 5, "name": "matter"})) conn.commit() self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() one_count = 0 @@ -1454,19 +1491,19 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - two.add(row.first_value, row.second_value) + two.add2(row.first_value, row.second_value) elif row.type == 3: # negative correlation around 11.8, 12.1 self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - three.add(row.first_value, row.second_value) + three.add2(row.first_value, row.second_value) elif row.type == 4: # positive correlation around 21.4, 23.4 self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - four.add(row.first_value, row.second_value) + four.add2(row.first_value, row.second_value) elif row.type == 5: 
self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) @@ -1517,7 +1554,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) - def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): + def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: """Test EAV for all columns with sampled and suppressed generation.""" table_name = "measurement" table2_name = "observation" @@ -1566,7 +1603,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.remove_data(gc.config) # let's add a vocab table without messing around with files table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(table).values({"id": 1, "name": "agreement"})) conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) conn.execute(insert(table).values({"id": 3, "name": "velocity"})) @@ -1574,7 +1611,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): conn.execute(insert(table).values({"id": 5, "name": "matter"})) conn.commit() self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() one_count = 0 @@ -1661,11 +1698,11 @@ class NonInteractiveTests(RequiresDBTestCase): ) def test_non_interactive_configure_generators( self, mock_csv_reader: MagicMock, mock_path: MagicMock - ): + ) -> None: """ test that we can set generators from a CSV file """ - config = {} + config: Mapping[str, Any] = {} spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( self.dsn, self.schema_name, self.metadata, config, spec_csv diff --git a/tests/test_remove.py b/tests/test_remove.py index 
bfbb787d..24286fba 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch from sqlalchemy import func, inspect, select +from sqlalchemy.engine import Connection from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict @@ -16,17 +17,16 @@ class RemoveThingsTestCase(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def count_rows(self, connection, table_name: str) -> int | None: + def count_rows(self, connection: Connection, table_name: str) -> int | None: return connection.execute( select(func.count()).select_from(self.metadata.tables[table_name]) ).scalar() @patch("datafaker.remove.get_settings") - def test_remove_data(self, mock_get_settings: MagicMock): + def test_remove_data(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) remove_db_data( self.metadata, @@ -37,9 +37,9 @@ def test_remove_data(self, mock_get_settings: MagicMock): } }, ) - with self.engine.connect() as conn: - self.assertGreater(self.count_rows(conn, "manufacturer"), 0) - self.assertGreater(self.count_rows(conn, "model"), 0) + with self.sync_engine.connect() as conn: + self.assertGreaterAndNotNone(self.count_rows(conn, "manufacturer"), 0) + self.assertGreaterAndNotNone(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) self.assertEqual(self.count_rows(conn, "string"), 0) self.assertEqual(self.count_rows(conn, "signature_model"), 0) @@ -50,7 +50,6 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: remove_db_data( @@ -67,13 +66,12 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: ) 
@patch("datafaker.remove.get_settings") - def test_remove_vocab(self, mock_get_settings: MagicMock): + def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) config = { "tables": { "manufacturer": {"vocabulary_table": True}, @@ -82,7 +80,7 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): } remove_db_data(self.metadata, config) remove_db_vocab(self.metadata, meta_dict, config) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(self.count_rows(conn, "manufacturer"), 0) self.assertEqual(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) @@ -95,10 +93,11 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) remove_db_vocab( self.metadata, meta_dict, @@ -114,19 +113,24 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: ) @patch("datafaker.remove.get_settings") - def test_remove_tables(self, mock_get_settings: MagicMock): + def test_remove_tables(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - self.assertTrue(inspect(self.engine).has_table("player")) + engine_in = inspect(self.engine) + assert engine_in is not None + assert hasattr(engine_in, "has_table") + self.assertTrue(engine_in.has_table("player")) remove_db_tables(self.metadata) - 
self.assertFalse(inspect(self.engine).has_table("manufacturer")) - self.assertFalse(inspect(self.engine).has_table("model")) - self.assertFalse(inspect(self.engine).has_table("player")) - self.assertFalse(inspect(self.engine).has_table("string")) - self.assertFalse(inspect(self.engine).has_table("signature_model")) + engine_out = inspect(self.engine) + assert engine_out is not None + assert hasattr(engine_out, "has_table") + self.assertFalse(engine_out.has_table("manufacturer")) + self.assertFalse(engine_out.has_table("model")) + self.assertFalse(engine_out.has_table("player")) + self.assertFalse(engine_out.has_table("string")) + self.assertFalse(engine_out.has_table("signature_model")) @patch("datafaker.remove.get_settings") def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: @@ -134,7 +138,6 @@ def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: remove_db_tables(self.metadata) diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index c64e281c..503a36f5 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -2,14 +2,7 @@ from pathlib import Path from unittest.mock import MagicMock -from sqlalchemy import ( - Boolean, - Column, - Integer, - Text, - UniqueConstraint, - insert, -) +from sqlalchemy import Boolean, Column, Integer, Text, UniqueConstraint, insert from sqlalchemy.ext.declarative import declarative_base from datafaker.unique_generator import UniqueGenerator diff --git a/tests/utils.py b/tests/utils.py index a75b4f04..4e9f2365 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,13 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import 
create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables, T +from datafaker.utils import ( + T, + create_db_engine, + get_sync_engine, + import_file, + sorted_non_vocabulary_tables, +) class SysExit(Exception): @@ -74,14 +80,22 @@ def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-nam return self.fail("".join(traceback.format_exception(result.exception))) - def assertSubset(self, set1: set[T], set2: set[T], msg: str | None=None) -> None: + def assertGreaterAndNotNone(self, left: float | None, right: float) -> None: + """ + Assert left is not None and greater than right + """ + if left is None: + self.fail("first argument is None") + else: + self.assertGreater(left, right) + + def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None: """Assert a set is a (non-strict) subset. - Args: - set1: The asserted subset. - set2: The asserted superset. - msg: Optional message to use on failure instead of a list of - differences. + :param set1: The asserted subset. + :param set2: The asserted superset. + :param msg: Optional message to use on failure instead of a list of + differences. 
""" try: difference = set1.difference(set2) @@ -235,7 +249,7 @@ def remove_data(self, config: Mapping[str, Any]) -> None: # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) - def create_data(self, config: Mapping[str, Any], num_passes: int=1) -> None: + def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) @@ -250,7 +264,9 @@ def create_data(self, config: Mapping[str, Any], num_passes: int=1) -> None: self.schema_name, ) - def generate_data(self, config: Mapping[str, Any], num_passes: int=1) -> Mapping[str, Any]: + def generate_data( + self, config: Mapping[str, Any], num_passes: int = 1 + ) -> Mapping[str, Any]: """ Replaces the DB's source data with generated data. :return: A Python dictionary representation of the src-stats.yaml file, for what it's worth. 
From 10b02c5f243964345d66181d8e87e5db568fc513 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 8 Oct 2025 18:14:40 +0100 Subject: [PATCH 16/44] precommit cleanup, NullPartitionedGrouped fix --- .github/workflows/tests.yml | 2 +- .pylintrc | 2 +- .readthedocs.yaml | 2 +- datafaker/create.py | 39 +++-- datafaker/dump.py | 3 +- datafaker/generators.py | 27 ++- datafaker/interactive.py | 36 ++-- datafaker/main.py | 25 ++- datafaker/make.py | 8 +- datafaker/providers.py | 1 + datafaker/serialize_metadata.py | 83 ++++++--- mypy.ini | 2 +- tests/test_functional.py | 29 +++- tests/test_interactive.py | 295 +++++++++++++++++--------------- tests/test_utils.py | 1 + 15 files changed, 334 insertions(+), 221 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 137557a7..75f45f03 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: - main env: # This should be the default but we'll be explicit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: the_job: runs-on: ubuntu-latest diff --git a/.pylintrc b/.pylintrc index d97276b9..22a92bd7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -53,7 +53,7 @@ persistent=yes # Min Python version to use for version dependend checks. Will default to the # version used to run pylint. -py-version=3.9 +py-version=3.10 # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. 
diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 91942b1b..29cdf780 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # You can also specify other tool versions: # nodejs: "19" # rust: "1.64" diff --git a/datafaker/create.py b/datafaker/create.py index 11b64a7f..f11a0dd3 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,6 +1,5 @@ """Functions and classes to create and populate the target database.""" import pathlib -import random from collections import Counter from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple @@ -125,6 +124,19 @@ def create_db_data_into( db_dsn: str, schema_name: str | None, ) -> RowCounts: + """ + Populate the database. + + :param sorted_tables: The table names to populate, sorted so that foreign + keys' targets are populated before the foreign keys themselves. + :param table_generator_dict: A mapping of table names to the generators + used to make data for them. + :param story_generator_list: A list of story generators to be run after the + table generators on each pass. + :param num_passes: Number of passes to perform. + :param db_dsn: Connection string for the destination database. + :param schema_name: Destination schema name. 
+ """ dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) row_counts: Counter[str] = Counter() @@ -140,6 +152,8 @@ def create_db_data_into( class StoryIterator: + """Iterates through all the rows produced by all the stories.""" + def __init__( self, stories: Iterable[tuple[str, Story]], @@ -147,6 +161,7 @@ def __init__( table_generator_dict: Mapping[str, TableGenerator], dst_conn: Connection, ): + """Initialise a Story Iterator.""" self._stories: Iterator[tuple[str, Story]] = iter(stories) self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict @@ -162,27 +177,31 @@ def __init__( def is_ended(self) -> bool: """ - Do we have another row to process? + Check if we have another row to process. + If so, insert() can be called. """ return self._table_name is None def has_table(self, table_name: str) -> bool: - """ - Do we have a row for table table_name? - """ + """Check if we have a row for table ``table_name``.""" return table_name == self._table_name def table_name(self) -> str | None: """ - The name of the current table (or None if no more stories to process) + Get the name of the current table. + + :return: The table name, or None if there are no more stories + to process. """ return self._table_name def insert(self) -> None: """ - Perform the insert. Call this after __init__ or next, and after checking - that is_ended returns False. + Put the row in the table. + + Call this after __init__ or next, and after checking that is_ended + returns False. """ if self._table_name is None: raise StopIteration("StoryIterator.insert after is_ended") @@ -210,9 +229,7 @@ def insert(self) -> None: cursor.close() def next(self) -> None: - """ - Advance to the next table row. 
- """ + """Advance to the next row.""" while True: try: if self._final_values is None: diff --git a/datafaker/dump.py b/datafaker/dump.py index 2ac187ff..95f251a7 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,3 +1,4 @@ +""" Data dumping functions. """ import csv import io from typing import TYPE_CHECKING @@ -12,7 +13,7 @@ def _make_csv_writer(file: io.TextIOBase) -> "Writer": - """Make the standard CSV file writer""" + """Make the standard CSV file writer.""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py b/datafaker/generators.py index aa0b20b2..ee0add2b 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union +from typing import Any, Callable, Iterable, Sequence, Union import mimesis import mimesis.locales @@ -21,7 +21,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import logger +from datafaker.utils import logger, T numeric = Union[int, float] @@ -1670,13 +1670,14 @@ def _actual_kwargs_with_combinations( "name": "constant", "params": {"value": [None] * len(partition.excluded_columns)}, } - if not partition.excluded_columns: + covariates = { + "covs": partition.covariates, + } + if not partition.constant_outputs: return { "count": count, "name": self._function_name, - "params": { - "covs": partition.covariates, - }, + "params": covariates, } return { "count": count, @@ -1684,9 +1685,7 @@ def _actual_kwargs_with_combinations( "params": { "constants_at": partition.constant_outputs, "subgen": self._function_name, - "params": { - "covs": partition.covariates, - }, + "params": covariates, }, } @@ -1718,9 +1717,6 @@ def is_numeric(col: Column) -> bool: return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not 
col.foreign_keys -T = TypeVar("T") - - def powerset(input: list[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @@ -1780,7 +1776,7 @@ class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SUPPRESS_COUNT = 5 EMPTY_RESULT = [ RowMapping( - parent=sqlalchemy.engine.result.ResultMetaData(), + parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), processors=None, key_to_index={"count": 0}, data=(0,), @@ -1942,10 +1938,11 @@ def _execute_partition_queries( """ found_nonzero = False for rp in partitions.values(): - rp.covariates = connection.execute(text(rp.query)).mappings().fetchall() - if not rp.covariates or rp.covariates[0]["count"] is None: + covs = connection.execute(text(rp.query)).mappings().fetchall() + if not covs or covs.count == 0 or covs[0]["count"] is None: rp.covariates = self.EMPTY_RESULT else: + rp.covariates = covs found_nonzero = True return found_nonzero diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 47f918cf..580fa3af 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1673,14 +1673,16 @@ def _get_custom_queries_from( ) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) + # Do we have a SRC_STAT reference? if src_stat_groups: + # Get its name cq_key = src_stat_groups.group(1) - if cq_key not in out: - out[cq_key] = [] + # Are we pulling a specific part of this result? 
sub = src_stat_groups.group(3) if sub: actual = {sub: actual} - out[cq_key].append(actual) + else: + out[cq_key] = actual elif type(nominal) is list and type(actual) is list: for i in range(min(len(nominal), len(actual))): self._get_custom_queries_from(out, nominal[i], actual[i]) @@ -1780,10 +1782,11 @@ def do_propose(self, _arg: str) -> None: ) def do_p(self, arg: str) -> None: - """Synonym for propose""" + """Synonym for propose.""" self.do_propose(arg) def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: + """Find a generator by name from the list of proposals.""" for gen in self._get_generator_proposals(): if gen.name() == gen_name: return gen @@ -1792,7 +1795,7 @@ def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: def do_set(self, arg: str) -> None: """ Set one of the proposals as a generator. - Takes a single integer argument. + :param arg: A single integer (as a string). """ if arg.isdigit() and not self._generators_valid(): self.print("Please run 'propose' before 'set '") @@ -1820,9 +1823,7 @@ def do_set(self, arg: str) -> None: self._go_next() def set_generator(self, gen: Generator | None) -> None: - """ - Set the current column's generator. - """ + """Set the current column's generator.""" (table, gen_info) = self.get_table_and_generator() if table is None: self.print("Error: no table") @@ -1833,18 +1834,21 @@ def set_generator(self, gen: Generator | None) -> None: gen_info.gen = gen def do_s(self, arg: str) -> None: - """Synonym for set""" + """Synonym for set.""" self.do_set(arg) def do_unset(self, _arg: str) -> None: - """ - Remove any generator set for this column. - """ + """Remove any generator set for this column.""" self.set_generator(None) self._go_next() def do_merge(self, arg: str) -> None: - """Add this column(s) to the specified column(s), so one generator covers them all.""" + """ + Add this column(s) to the specified column(s). + + After this, one generator will cover them all. 
+ :param arg: space separated list of column names to merge. + """ cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1899,6 +1903,7 @@ def do_merge(self, arg: str) -> None: def complete_merge( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Complete column names.""" last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1954,6 +1959,7 @@ def do_unmerge(self, arg: str) -> None: def complete_unmerge( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Complete column names to unmerge.""" last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1969,7 +1975,7 @@ def update_config_generators( src_dsn: str, src_schema: str | None, metadata: MetaData, - config: Mapping[str, Any], + config: MutableMapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ @@ -1981,7 +1987,7 @@ def update_config_generators( :param src_dsn: Address of the source database :param src_schema: Name of the source database schema to read from :param metadata: SQLAlchemy representation of the source database - :param config: Existing configuration (will not be destructively updated) + :param config: Existing configuration (will be destructively updated) :param spec_path: The path of the CSV file containing the specification :return: Updated configuration. 
""" diff --git a/datafaker/main.py b/datafaker/main.py index 65b0102a..454cf446 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,16 +1,17 @@ """Entrypoint for the datafaker package.""" import asyncio +import importlib import io import json import sys from enum import Enum -from importlib import metadata from pathlib import Path from typing import Any, Final, Optional import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate +from sqlalchemy import MetaData from typer import Argument, Exit, Option, Typer from datafaker.create import create_db_data, create_db_tables, create_db_vocab @@ -81,7 +82,13 @@ def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: return meta_dict -def load_metadata(orm_file_name: str, config: dict | None = None) -> Any: +def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: + """ + Load metadata from ``orm.yaml`` + :param orm_file_name: ``orm.yaml`` or alternative name to load metadata from. + :param config: Used to exclude tables that are marked as ``ignore: true``. + :return: SQLAlchemy MetaData object representing the database described by the loaded file. 
+ """ meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) @@ -548,16 +555,16 @@ def remove_tables( class TableType(str, Enum): - all = "all" - vocab = "vocab" - generated = "generated" + ALL = "all" + VOCAB = "vocab" + GENERATED = "generated" @app.command() def list_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - tables: TableType = Option(TableType.generated, help="Which tables to list"), + tables: TableType = Option(TableType.GENERATED, help="Which tables to list"), ) -> None: """List the names of tables described in the metadata file.""" config = read_config_file(config_file) if config_file is not None else {} @@ -568,9 +575,9 @@ def list_tables( for (table_name, table_config) in config.get("tables", {}).items() if get_flag(table_config, "vocabulary_table") } - if tables == TableType.all: + if tables == TableType.ALL: names = all_table_names - elif tables == TableType.generated: + elif tables == TableType.GENERATED: names = all_table_names - vocab_table_names else: names = vocab_table_names @@ -584,7 +591,7 @@ def version() -> None: logger.info( "%s version %s", __package__, - metadata.version(__package__), + importlib.metadata.version(__package__), ) diff --git a/datafaker/make.py b/datafaker/make.py index ac7cdfce..e9db8636 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -19,7 +19,7 @@ from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table -from sqlalchemy.sql import Executable, sqltypes, type_api +from sqlalchemy.sql import Executable, sqltypes from typing_extensions import Self from datafaker import providers @@ -825,6 +825,12 @@ async def make_src_stats( async def make_src_stats_connection( config: Mapping, db_conn: DbConnection, metadata: MetaData ) -> dict[str, dict[str, Any]]: + """ + 
Make the ``src-stats.yaml`` file given the database connection to read from. + :param config: configuration from ``config.yaml``. + :param db_conn: Source database connection. + :param metadata: Source database metadata from ``orm.yaml``. + """ date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/providers.py b/datafaker/providers.py index 03e6cbe2..65abf069 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -30,6 +30,7 @@ def column_value( return None def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: + """Initialise the column value provider.""" super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index d407d494..0b96e346 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,5 +1,6 @@ +"""Convert between a Python dict describing a database schema and a SQLAlchemy MetaData.""" import typing -from typing import Callable, Protocol +from typing import Callable import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table @@ -8,7 +9,7 @@ from datafaker.utils import make_foreign_key_name -table_t = dict[str, typing.Any] +TableT = dict[str, typing.Any] # We will change this to parsy.Parser when parsy exports its types properly @@ -17,7 +18,8 @@ def simple(type_: type) -> ParserType: """ - Parses a simple sqltypes type. + Get a parser for a simple sqltypes type. + For example, simple(sqltypes.UUID) takes the string "UUID" and outputs a UUID class, or fails with any other string. """ @@ -26,14 +28,15 @@ def simple(type_: type) -> ParserType: def integer() -> ParserType: """ - Parses an integer, outputting that integer. + Get a parser for an integer, outputting that integer. 
""" return parsy.regex(r"-?[0-9]+").map(int) def integer_arguments() -> ParserType: """ - Parses a list of integers. + Get a parser for a list of integers. + The integers are surrounded by brackets and separated by a comma and space. """ @@ -44,6 +47,8 @@ def integer_arguments() -> ParserType: def numeric_type(type_: type) -> ParserType: """ + Make a parser for a SQL numeric type. + Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. """ @@ -53,12 +58,16 @@ def numeric_type(type_: type) -> ParserType: def string_type(type_: type) -> ParserType: + """ + Make a parser for a SQL string type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" + or TYPE_NAME(32) COLLATE "fr" + """ + @parsy.generate(type_.__name__) def st_parser() -> typing.Generator[ParserType, None, typing.Any]: - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" - or TYPE_NAME(32) COLLATE "fr" - """ + """Parse the specific type.""" yield parsy.string(type_.__name__) length: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") @@ -72,12 +81,22 @@ def st_parser() -> typing.Generator[ParserType, None, typing.Any]: def time_type(type_: type, pg_type: type) -> ParserType: + """ + Make a parser for a SQL date/time type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE + or TYPE_NAME(32) WITH TIME ZONE + + :param type_: The SQLAlchemy type we would like to parse. + :param pg_type: The PostgreSQL type we would like to parse if precision + or timezone is provided. + :return: ``type_`` if neither precision nor timezone are provided in the + parsed text, ``pg_type(precision, timezone)`` otherwise. 
+ """ + @parsy.generate(type_.__name__) def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE - or TYPE_NAME(32) WITH TIME ZONE - """ + """Parse the actual type.""" yield parsy.string(type_.__name__) precision: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") @@ -130,6 +149,11 @@ def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: @parsy.generate def type_parser() -> ParserType: + """ + Make a parser for a simple type or an array. + + Arrays produce a PostgreSQL-specific type. + """ base = yield SIMPLE_TYPE_PARSER dimensions = yield parsy.string("[]").many().map(len) if dimensions == 0: @@ -138,6 +162,11 @@ def type_parser() -> ParserType: def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: + """ + Produce a dict description of a column. + :param column: The SQLAlchemy column to translate. + :param dialect: The SQL dialect in which to render the type name. + """ type_ = column.type if isinstance(type_, postgresql.DOMAIN): # Instead of creating a restricted type, we'll just use the base type. @@ -165,6 +194,20 @@ def dict_to_column( rep: dict, ignore_fk: Callable[[str], bool], ) -> Column: + """ + Produce column from aspects of its dict description. + :param table_name: The name of the table the column appears in. + :param col_name: The name of the column. + :param rep: The dict description of the column. + :ignore_fk: A predicate, called with the name of any foreign key target + (in other words, the name of any table referred to by this column). If it + returns True, this foreign key constraint will not be applied to the + returned column. 
This is useful in a situation where we want a foreign + key constraint to be present when we are determining what generators + might be appropriate for it, but we don't want the foreign key constraint + actually applied to the destination database because (for example) the + target table will be ignored. + """ type_sql = rep["type"] try: type_ = type_parser.parse(type_sql) @@ -193,21 +236,20 @@ def dict_to_column( def dict_to_unique(rep: dict) -> schema.UniqueConstraint: + """Make a uniqueness constraint from its dict representation.""" return schema.UniqueConstraint(*rep.get("columns", []), name=rep.get("name", None)) def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: + """Render a dict representation of a uniqueness constraint.""" return { "name": constraint.name, "columns": [str(col.name) for col in constraint.columns], } -def table_to_dict(table: Table, dialect: Dialect) -> table_t: - """ - Converts a SQL Alchemy Table object into a - Python object ready for conversion to YAML. - """ +def table_to_dict(table: Table, dialect: Dialect) -> TableT: + """Converts a SQL Alchemy Table object into a Python dict.""" return { "columns": { str(column.key): column_to_dict(column, dialect) @@ -224,9 +266,10 @@ def table_to_dict(table: Table, dialect: Dialect) -> table_t: def dict_to_table( name: str, meta: MetaData, - table_dict: table_t, + table_dict: TableT, ignore_fk: Callable[[str], bool], ) -> Table: + """Create a Table from its description.""" return Table( name, meta, @@ -255,7 +298,7 @@ def metadata_to_dict( } -def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool: +def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool: """ Tell if this foreign key should be ignored because it points to an ignored table. 
diff --git a/mypy.ini b/mypy.ini index 86ff2fb3..c2ea784f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ # Global options: [mypy] -python_version = 3.9 +python_version = 3.10 disallow_untyped_defs = True disallow_any_unimported = True no_implicit_optional = True diff --git a/tests/test_functional.py b/tests/test_functional.py index 394691b3..e60baa1e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -117,8 +117,18 @@ def test_workflow_minimal_args(self) -> None: self.assertNoException(completed_process) self.assertEqual( { - "Unsupported SQLAlchemy type CIDR for column column_with_unusual_type. Setting this column to NULL always, you may want to configure a row generator for it instead.", - "Unsupported SQLAlchemy type BIT for column column_with_unusual_type_and_length. Setting this column to NULL always, you may want to configure a row generator for it instead.", + ( + "Unsupported SQLAlchemy type CIDR for column " + "column_with_unusual_type. Setting this column to NULL " + "always, you may want to configure a row generator for " + "it instead." + ), + ( + "Unsupported SQLAlchemy type BIT for column " + "column_with_unusual_type_and_length. Setting this column " + "to NULL always, you may want to configure a row generator " + "for it instead." 
+ ), }, set(completed_process.stderr.split("\n")) - {""}, ) @@ -309,7 +319,10 @@ def test_workflow_maximal_args(self) -> None: self.assertSetEqual( { "Dropping constraint concept_concept_type_id_fkey from table concept", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", "Restoring foreign key constraint concept_concept_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", @@ -408,8 +421,14 @@ def test_workflow_maximal_args(self) -> None: 'Truncating vocabulary table "mitigation_type".', 'Truncating vocabulary table "empty_vocabulary".', "Vocabulary tables truncated.", - "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint concept_type_mitigation_type_id_fkey " + "from table concept_type" + ), + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_concept_type_id_fkey from table concept", "Restoring foreign key constraint concept_type_mitigation_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", diff --git a/tests/test_interactive.py b/tests/test_interactive.py index e1de3b3a..a7803f01 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -286,7 +286,7 @@ def test_list_tables(self) -> None: person_listed = False unique_constraint_test_listed = False no_pk_test_listed = False - for text, args, kwargs in tc.messages: + for _text, args, _kwargs in tc.messages: if args[2] == "person": self.assertFalse(person_listed) person_listed = True @@ -477,13 +477,13 @@ def test_null_configuration(self) -> None: "tables": None, } with 
self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") + table = "model" + gc.do_next(f"{table}.name") gc.do_propose("") gc.do_compare("") gc.do_set("1") gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) def test_null_table_configuration(self) -> None: """Test that a table having null configuration does not break.""" @@ -493,12 +493,12 @@ def test_null_table_configuration(self) -> None: } } with self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") + table = "model" + gc.do_next(f"{table}.name") gc.do_propose("") gc.do_set("1") gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) def test_prompts(self) -> None: """Test that the prompts follow the names of the columns and assigned generators.""" @@ -543,94 +543,94 @@ def test_prompts(self) -> None: def test_set_generator_mimesis(self) -> None: """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) + gc.do_set(str(proposals[f"generic.{generator}"][0])) gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, ) def test_set_generator_distribution(self) -> None: 
"""Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) self.assertDictEqual( row_gen["kwargs"], { - "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', + "mean": f'SRC_STATS["auto__{table}"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__{table}"]["results"][0]["stddev__{column}"]', }, ) self.assertEqual(len(gc.config["src-stats"]), 1) self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", ) def test_set_generator_distribution_directly(self) -> None: """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") + 
table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") gc.reset() - gc.do_set(GENERATOR) + gc.do_set(generator) self.assertListEqual(gc.messages, []) gc.do_quit("") self.assertEqual(len(gc.config["src-stats"]), 1) self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", ) def test_set_generator_choice(self) -> None: """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.choice" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "frequency" + generator = "dist_gen.choice" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) self.assertDictEqual( row_gen["kwargs"], { - "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', + "a": f'SRC_STATS["auto__{table}__{column}"]["results"]', }, ) self.assertEqual(len(gc.config["src-stats"]), 1) @@ -638,108 +638,108 @@ def test_set_generator_choice(self) -> None: set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) 
self.assertEqual( - gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}" + gc.config["src-stats"][0]["name"], f"auto__{table}__{column}" ) self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", + f"SELECT {column} AS value FROM {table} WHERE {column} IS NOT NULL GROUP BY value ORDER BY COUNT({column}) DESC", ) def test_weighted_choice_generator_generates_choices(self) -> None: """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "position" - GENERATOR = "dist_gen.weighted_choice" - VALUES = {1, 2, 3, 4, 5, 6} - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "position" + generator = "dist_gen.weighted_choice" + values = {1, 2, 3, 4, 5, 6} + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gen_proposal = proposals[GENERATOR] - self.assertSubset(set(gen_proposal[2]), {str(v) for v in VALUES}) + gen_proposal = proposals[generator] + self.assertSubset(set(gen_proposal[2]), {str(v) for v in values}) gc.do_compare(str(gen_proposal[0])) - col_heading = f"{gen_proposal[0]}. {GENERATOR}" + col_heading = f"{gen_proposal[0]}. 
{generator}" self.assertIn(col_heading, gc.columns) - self.assertSubset(set(gc.columns[col_heading]), VALUES) + self.assertSubset(set(gc.columns[col_heading]), values) def test_merge_columns(self) -> None: """Test that we can merge columns and set a multivariate generator""" - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - GENERATOR_TO_DISCARD = "dist_gen.choice" - GENERATOR = "dist_gen.multivariate_normal" + table = "string" + column_1 = "frequency" + column_2 = "position" + generator_to_discard = "dist_gen.choice" + generator = "dist_gen.multivariate_normal" with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") + gc.do_next(f"{table}.{column_2}") gc.do_propose("") proposals = gc.get_proposals() # set a generator, but this should not exist after merging - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) - gc.do_next(f"{TABLE}.{COLUMN_1}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) + gc.do_set(str(proposals[generator_to_discard][0])) + gc.do_next(f"{table}.{column_1}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) gc.do_propose("") proposals = gc.get_proposals() # set a generator, but this should not exist either - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) + gc.do_set(str(proposals[generator_to_discard][0])) gc.do_previous("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) - gc.do_merge(COLUMN_2) - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) + gc.do_merge(column_2) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) gc.reset() gc.do_propose("") proposals = gc.get_proposals() - 
gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column_1, column_2]) def test_unmerge_columns(self) -> None: """Test that we can unmerge columns and generators are removed""" - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - COLUMN_3 = "model_id" - REMAINING_GEN = "gen3" + table = "string" + column_1 = "frequency" + column_2 = "position" + column_3 = "model_id" + remaining_gen = "gen3" config = { "tables": { - TABLE: { + table: { "row_generators": [ - {"name": "gen1", "columns_assigned": [COLUMN_1, COLUMN_2]}, - {"name": REMAINING_GEN, "columns_assigned": [COLUMN_3]}, + {"name": "gen1", "columns_assigned": [column_1, column_2]}, + {"name": remaining_gen, "columns_assigned": [column_3]}, ] } } } with self._get_cmd(config) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) - gc.do_unmerge(COLUMN_1) - self.assertIn(TABLE, gc.prompt) - self.assertNotIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) + gc.do_next(f"{table}.{column_2}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) + gc.do_unmerge(column_1) + self.assertIn(table, gc.prompt) + self.assertNotIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) # Next generator should be the unmerged one gc.do_next("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, 
gc.prompt) + self.assertNotIn(column_2, gc.prompt) gc.do_quit("") # Both generators should have disappeared - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], REMAINING_GEN) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) + self.assertEqual(row_gen["name"], remaining_gen) + self.assertListEqual(row_gen["columns_assigned"], [column_3]) def test_old_generators_remain(self) -> None: """Test that we can set one generator and keep an old one.""" @@ -766,18 +766,18 @@ def test_old_generators_remain(self) -> None: ], } with self._get_cmd(config) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) + gc.do_set(str(proposals[f"generic.{generator}"][0])) gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, ) row_gens = gc.config["tables"]["string"]["row_generators"] self.assertEqual(len(row_gens), 1) @@ -795,7 +795,7 @@ def test_old_generators_remain(self) -> None: self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__string") + self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], "SELECT AVG(frequency) AS 
mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", @@ -829,31 +829,31 @@ def test_aggregate_queries_merge(self) -> None: ], } with self._get_cmd(copy.deepcopy(config)) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - gc.do_next(f"string.{COLUMN}") + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) + gc.do_set(str(proposals[f"{generator}"][0])) gc.do_quit("") row_gens: list[dict[str, Any]] = gc.config["tables"]["string"][ "row_generators" ] self.assertEqual(len(row_gens), 2) - if row_gens[0]["name"] == GENERATOR: + if row_gens[0]["name"] == generator: row_gen0 = row_gens[0] row_gen1 = row_gens[1] else: row_gen0 = row_gens[1] row_gen1 = row_gens[0] - self.assertEqual(row_gen0["name"], GENERATOR) + self.assertEqual(row_gen0["name"], generator) self.assertEqual(row_gen1["name"], "dist_gen.normal") - self.assertListEqual(row_gen0["columns_assigned"], [COLUMN]) + self.assertListEqual(row_gen0["columns_assigned"], [column]) self.assertDictEqual( row_gen0["kwargs"], { - "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', + "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{column}"]', }, ) self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) @@ -877,8 +877,8 @@ def test_aggregate_queries_merge(self) -> None: { "AVG(frequency) AS mean__frequency", "STDDEV(frequency) AS stddev__frequency", - f"AVG({COLUMN}) AS mean__{COLUMN}", - f"STDDEV({COLUMN}) AS stddev__{COLUMN}", + f"AVG({column}) AS mean__{column}", + f"STDDEV({column}) AS stddev__{column}", }, ) @@ -954,12 +954,12 @@ def test_existing_configuration_remains(self) -> None: ], } with self._get_cmd(config) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - 
gc.do_next(f"string.{COLUMN}") + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) + gc.do_set(str(proposals[f"{generator}"][0])) gc.do_quit("") src_stats = {stat["name"]: stat["query"] for stat in gc.config["src-stats"]} self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) @@ -1170,8 +1170,8 @@ def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: def test_set_missingness_to_sampled(self) -> None: """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: - TABLE = "signature_model" - mc.do_next(TABLE) + table = "signature_model" + mc.do_next(table) mc.do_counts("") self.assertListEqual( mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] @@ -1181,7 +1181,7 @@ def test_set_missingness_to_sampled(self) -> None: mc.do_sampled("") mc.do_quit("") self.assertListEqual( - mc.config["tables"][TABLE]["missingness_generators"], + mc.config["tables"][table]["missingness_generators"], [ { "columns": ["player_id", "based_on"], @@ -1199,9 +1199,12 @@ def test_set_missingness_to_sampled(self) -> None: self.assertEqual( mc.config["src-stats"][0]["query"], ( - "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + "SELECT COUNT(*) AS row_count," + " player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null," + " based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000)" + " AS __t GROUP BY player_id__is_null, based_on__is_null" ), ) @@ -1325,7 +1328,7 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self) -> None: ] self.assertListEqual(based_ons, [1, 3, 2]) - def 
assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: + def assert_are_truncated_to(self, xs: Iterable[str], length: int) -> None: """ Check that none of the strings are longer than ``length`` (after removing surrounding quotes). @@ -1339,62 +1342,71 @@ def assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: def test_varchar_ns_are_truncated(self) -> None: """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" - GENERATOR = "generic.text.quote" - TABLE = "signature_model" - COLUMN = "name" + generator = "generic.text.quote" + table = "signature_model" + column = "name" with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN}") + gc.do_next(f"{table}.{column}") gc.reset() gc.do_propose("") proposals = gc.get_proposals() - quotes = [k for k in proposals.keys() if k.startswith(GENERATOR)] + quotes = [k for k in proposals.keys() if k.startswith(generator)] self.assertEqual(len(quotes), 1) prop = proposals[quotes[0]] - self.assertAreTruncatedTo(prop[2], 20) + self.assert_are_truncated_to(prop[2], 20) gc.reset() gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. 
{quotes[0]}" gc.do_set(str(prop[0])) self.assertIn(col_heading, gc.columns) - self.assertAreTruncatedTo(gc.columns[col_heading], 20) + self.assert_are_truncated_to(gc.columns[col_heading], 20) gc.do_quit("") config = gc.config self.generate_data(config, num_passes=15) with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[TABLE].c[COLUMN]) + stmt = select(self.metadata.tables[table].c[column]) rows = conn.execute(stmt).scalars().fetchall() - self.assertAreTruncatedTo(rows, 20) + self.assert_are_truncated_to(rows, 20) @dataclass class Stat: + """Mean and variance calculator.""" + n: int = 0 x: float = 0 x2: float = 0 def add(self, x: float) -> None: + """Add one datum.""" self.n += 1 self.x += x self.x2 += x * x def count(self) -> int: + """Get the number of data added.""" return self.n def x_mean(self) -> float: + """Get the mean of the added data.""" return self.x / self.n def x_var(self) -> float: + """Get the variance of the added data.""" x = self.x return (self.x2 - x * x / self.n) / (self.n - 1) @dataclass class Correlation(Stat): + """Mean, variance and covariance.""" + y: float = 0 y2: float = 0 xy: float = 0 def add2(self, x: float, y: float) -> None: + """Add a 2D data point.""" self.n += 1 self.x += x self.x2 += x * x @@ -1403,13 +1415,16 @@ def add2(self, x: float, y: float) -> None: self.xy += x * y def y_mean(self) -> float: + """Get the mean of the second parts of the added points.""" return self.y / self.n def y_var(self) -> float: + """Get the variance of the second parts of the added points.""" y = self.y return (self.y2 - y * y / self.n) / (self.n - 1) def covar(self) -> float: + """Get the covariance of the two parts of the added points.""" return (self.xy - self.x * self.y / self.n) / (self.n - 1) @@ -1702,7 +1717,7 @@ def test_non_interactive_configure_generators( """ test that we can set generators from a CSV file """ - config: Mapping[str, Any] = {} + config: MutableMapping[str, Any] = {} spec_csv = 
Mock(return_value="mock spec.csv file") update_config_generators( self.dsn, self.schema_name, self.metadata, config, spec_csv diff --git a/tests/test_utils.py b/tests/test_utils.py index b34e9227..54c167a2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -288,6 +288,7 @@ def test_generators_require_stats(self) -> None: @patch("datafaker.utils.logger") def test_testing_generators_finds_syntax_errors(self, logger: MagicMock) -> None: + """Test that looking for ``SRC_STATS`` references finds Python syntax errors.""" generators_require_stats( { "story_generators": [ From 42fb24a0420b5f2bd83ee5996ebdec706a2fdc0b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Oct 2025 18:58:11 +0100 Subject: [PATCH 17/44] Many, many cleanups. --- CONTRIBUTING.md | 3 +- datafaker/base.py | 107 +++++++++-- datafaker/dump.py | 4 +- datafaker/generators.py | 299 +++++++++++++++++++---------- datafaker/interactive.py | 327 +++++++++++++++++++++++--------- datafaker/main.py | 48 +++-- datafaker/make.py | 76 +++++--- datafaker/providers.py | 2 +- datafaker/serialize_metadata.py | 39 ++-- datafaker/utils.py | 109 +++++++++-- tests/test_dump.py | 6 +- tests/test_functional.py | 85 ++++++--- tests/test_interactive.py | 37 +++- tests/test_main.py | 2 +- tests/test_remove.py | 4 + tests/test_rst.py | 16 +- 16 files changed, 836 insertions(+), 328 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 259ebe8e..8d7b8799 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,7 +55,8 @@ These tests do not currently work, and will be replaced by unit tests. Functional tests require PostgreSQL to be installed. - *WARNING: Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). 
As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file.* +..warning:: + Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file. ## Building documentation locally diff --git a/datafaker/base.py b/datafaker/base.py index 8270ccb9..4ceb2aff 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -27,6 +27,7 @@ @functools.cache def zipf_weights(size: int) -> list[float]: + """Get the weights of a Zipf distribution of a given size.""" total = sum(map(lambda n: 1 / n, range(1, size + 1))) return [1 / (n * total) for n in range(1, size + 1)] @@ -36,6 +37,7 @@ def merge_with_constants( ) -> Generator[T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. + :param constants_at: A map of indices to objects that must be placed at those indices. :param xs: Items that fill in the gaps left by ``constants_at``. @@ -62,35 +64,88 @@ def merge_with_constants( class NothingToGenerateException(Exception): + """Exception thrown when no value can be generated.""" + def __init__(self, message: str): + """Initialise the exception with a human-readable message.""" super().__init__(message) class DistributionGenerator: + """An object that can produce values from various distributions.""" + root3 = math.sqrt(3) def __init__(self) -> None: + """Initialise the DistributionGenerator.""" self.np_gen = np.random.default_rng() def uniform(self, low: float, high: float) -> float: + """ + Choose a value according to a uniform distribution. + + :param low: The lowest value that can be chosen. + :param high: The highest value that can be chosen. + :return: The output value. + """ return random.uniform(float(low), float(high)) def uniform_ms(self, mean: float, sd: float) -> float: + """ + Choose a value according to a uniform distribution. 
+ + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ m = float(mean) h = self.root3 * float(sd) return random.uniform(m - h, m + h) def normal(self, mean: float, sd: float) -> float: + """ + Choose a value according to a Gaussian (normal) distribution. + + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ return random.normalvariate(float(mean), float(sd)) def lognormal(self, logmean: float, logsd: float) -> float: + """ + Choose a value according to a lognormal distribution. + + :param logmean: The mean of the logs of the output values. + :param logsd: The standard deviation of the logs of the output values. + :return: The output value. + """ return random.lognormvariate(float(logmean), float(logsd)) def choice(self, a: list[T]) -> T: + """ + Choose a value with equal probability. + + :param a: The list of values to output. Each element is either + the value itself, or a mapping with a key ``value`` whose entry + is the value to return. + :return: The chosen value. + """ c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c def zipf_choice(self, a: list[T], n: int | None = None) -> T: + """ + Choose a value according to the Zipf distribution. + + The nth value (starting from 1) is chosen with a frequency + 1/n times that of the first value. + + :param a: The list of values to output, most frequent first. + Each element is either the value itself, or a mapping with + a key ``value`` whose entry is the value to return. + :return: The chosen value. + """ if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] @@ -99,9 +154,11 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T: def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ Choice weighted by the count in the original dataset. 
+ :param a: a list of dicts, each with a ``value`` key holding the value to be returned and a ``count`` key holding the number of that value found in the original dataset + :return: The chosen ``value``. """ vs = [] counts = [] @@ -114,9 +171,19 @@ def weighted_choice(self, a: list[dict[str, Any]]) -> Any: return c def constant(self, value: T) -> T: + """Return the same value always.""" return value def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: + """ + Return an array of values chosen from the given covariates. + + :param cov: Keys are ``rank``: The number of values to output; + ``mN``: The mean of variable ``N`` (where ``N`` is between 0 and + one less than ``rank``). ``cN_M`` (where 0 < ``N`` <= ``M`` < ``rank``): + the covariance between the ``N``th and the ``M``th variables. + :return: A numpy array of results. + """ rank = int(cov["rank"]) if rank == 0: return np.empty(shape=(0,)) @@ -131,9 +198,7 @@ def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: return self.np_gen.multivariate_normal(mean, covs) def _select_group(self, alts: list[dict[str, Any]]) -> Any: - """ - Choose one of the ``alts`` weighted by their ``"count"`` elements. - """ + """Choose one of the ``alts`` weighted by their ``"count"`` elements.""" total = 0 for alt in alts: if alt["count"] < 0: @@ -204,9 +269,7 @@ def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: return out def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: - """ - Produce a list of values pulled from a set of multivariate distributions. 
- """ + """Produce a list of values pulled from a set of multivariate distributions.""" cov = self._select_group(covs) logger.debug("Multivariate normal group selected: %s", cov) constants = self._find_constants(cov) @@ -214,9 +277,7 @@ def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: return list(merge_with_constants(nums, constants)) def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: - """ - Produce a list of values pulled from a set of multivariate distributions. - """ + """Produce a list of values pulled from a set of multivariate distributions.""" cov = self._select_group(covs) logger.debug("Multivariate lognormal group selected: %s", cov) constants = self._find_constants(cov) @@ -233,7 +294,7 @@ def alternatives( counts: list[dict[str, int]] | None, ) -> Any: """ - A generator that picks between other generators. + Pick between other generators. :param alternative_configs: List of alternative generators. Each alternative has the following keys: "count" -- a weight for @@ -264,6 +325,17 @@ def alternatives( def with_constants_at( self, constants_at: dict[int, T], subgen: str, params: dict[str, T] ) -> list[T]: + """ + Insert constants into the results of a different generator. + + :param constants_at: A dictionary of positions and objects to insert + into the return list at those positions. + :param subgen: The name of the function to call to get the results + that will have the constants inserted into. + :param params: Keyword arguments to the ``subgen`` function. + :return: A list of results from calling ``subgen(**params)`` + with ``constants_at`` inserted in at the appropriate indices. + """ if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. 
Valid names are %s.", @@ -277,7 +349,7 @@ def with_constants_at( def truncated_string( self, subgen_fn: Callable[..., list[T]], params: dict, length: int ) -> list[T]: - """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" + """Call ``subgen_fn(**params)`` and truncate the results to ``length``.""" result = subgen_fn(**params) if result is None: return None @@ -358,7 +430,18 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None: class ColumnPresence: - def sampled(self, patterns: list[dict[str, Any]]) -> set[Any]: + """Object for generators to use for missingness completely at random.""" + + def sampled(self, patterns: list[dict[str, Any]]) -> set[str]: + """ + Select a random pattern and output the non-null columns. + + :param patterns: List of outputs from missingness SQL queries. + Columns in each output: ``row_count`` is the number of rows + with this missingness pattern, then for each column + ``<column>`` there is a boolean called ``<column>__is_null``. + :return: All the names of the columns to make non-null. + """ total = 0 for pattern in patterns: total += pattern.get("row_count", 0) diff --git a/datafaker/dump.py b/datafaker/dump.py index 95f251a7..2307ba41 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,4 +1,4 @@ -""" Data dumping functions. """ +"""Data dumping functions.""" import csv import io from typing import TYPE_CHECKING @@ -35,4 +35,4 @@ def dump_db_tables( with engine.connect() as connection: result = connection.execute(sqlalchemy.select(table)) for row in result: - csv_out.writerow(row._tuple()) + csv_out.writerow(row) diff --git a/datafaker/generators.py b/datafaker/generators.py index ee0add2b..4e421af1 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -1,6 +1,4 @@ -""" -Generator factories for making generators for single columns. 
-""" +"""Generator factories for making generators for single columns.""" import decimal import math @@ -21,7 +19,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import logger, T +from datafaker.utils import T, logger numeric = Union[int, float] @@ -49,11 +47,11 @@ class Generator(ABC): @abstractmethod def function_name(self) -> str: - """The name of the generator function to put into df.py.""" + """Get the name of the generator function to put into df.py.""" def name(self) -> str: """ - The name of the generator. + Get the name of the generator. Usually the same as the function name, but can be different to distinguish between generators that have the same function but different queries. @@ -63,7 +61,8 @@ def name(self) -> str: @abstractmethod def nominal_kwargs(self) -> dict[str, str]: """ - The kwargs the generator wants to be called with. + Get the kwargs the generator wants to be called with. + The values will tend to be references to something in the src-stats.yaml file. For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will @@ -74,7 +73,7 @@ def nominal_kwargs(self) -> dict[str, str]: def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: """ - SQL clauses to add to a SELECT ... FROM {table} query. + Get the SQL clauses to add to a SELECT ... FROM {table} query. Will add to SRC_STATS["auto__{table}"] For example { @@ -94,7 +93,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: def custom_queries(self) -> dict[str, dict[str, str]]: """ - SQL queries to add to SRC_STATS. + Get the SQL queries to add to SRC_STATS. Should be used for queries that do not follow the SELECT ... FROM table format using aggregate queries, because these should use select_aggregate_clauses. 
@@ -114,14 +113,14 @@ def custom_queries(self) -> dict[str, dict[str, str]]: @abstractmethod def actual_kwargs(self) -> dict[str, Any]: """ - The kwargs (summary statistics) this generator is instantiated with. + Get the kwargs (summary statistics) this generator is instantiated with. + + This must match `nominal_kwargs` in structure. """ @abstractmethod def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. - """ + """Generate ``count`` random data points for this column.""" def fit(self, default: float = -1) -> float: """ @@ -134,9 +133,7 @@ def fit(self, default: float = -1) -> float: class PredefinedGenerator(Generator): - """ - Generator built from an existing config.yaml. - """ + """Generator built from an existing config.yaml.""" SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") @@ -168,6 +165,7 @@ def __init__( ): """ Initialise a generator from a config.yaml. + :param config: The entire configuration. :param generator_object: The part of the configuration at tables.*.row_generators """ @@ -219,24 +217,30 @@ def __init__( } def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" return self._kwn def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" return self._select_aggregate_clauses def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" return self._custom_queries def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" # Run the queries from nominal_kwargs # ... 
logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" # Call the function if we can. This could be tricky... # ... logger.error("PredefinedGenerator.generate_data not implemented yet") @@ -244,21 +248,19 @@ def generate_data(self, count: int) -> list[Any]: class GeneratorFactory(ABC): - """ - A factory for making generators appropriate for a database column. - """ + """A factory for making generators appropriate for a database column.""" @abstractmethod def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: - """ - Returns all the generators that might be appropriate for this column. - """ + """Get the generators appropriate to these columns.""" class Buckets: """ + Measured buckets for a real distribution. + Finds the real distribution of continuous data so that we can measure the fit of generators against it. """ @@ -272,6 +274,7 @@ def __init__( stddev: float, count: int, ): + """Initialise a Buckets object.""" with engine.connect() as connection: raw_buckets = connection.execute( text( @@ -330,15 +333,11 @@ def make_buckets( return buckets def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: - """ - Figure out the fit from bucket counts from the generator distribution. - """ + """Figure out the fit from bucket counts from the generator distribution.""" return fit_from_buckets(self.buckets, bucket_counts) def fit_from_values(self, values: list[float]) -> float: - """ - Figure out the fit from samples from the generator distribution. 
- """ + """Figure out the fit from samples from the generator distribution.""" buckets = [0] * 10 x = self.mean - 2 * self.stddev w = self.stddev / 2 @@ -352,12 +351,14 @@ class MultiGeneratorFactory(GeneratorFactory): """A composite factory.""" def __init__(self, factories: list[GeneratorFactory]): + """Initialise a MultiGeneratorFactory.""" super().__init__() self.factories = factories def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" return [ generator for factory in self.factories @@ -366,12 +367,14 @@ def get_generators( class MimesisGeneratorBase(Generator): + """Base class for a generator using Mimesis.""" + def __init__( self, function_name: str, ): """ - Generator from Mimesis. + Initialise a generator that uses Mimesis. :param function_name: is relative to 'generic', for example 'person.name'. """ @@ -391,13 +394,17 @@ def __init__( self._generator_function = f def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self._generator_function() for _ in range(count)] class MimesisGenerator(MimesisGeneratorBase): + """A generator using Mimesis.""" + def __init__( self, function_name: str, @@ -405,7 +412,7 @@ def __init__( buckets: Buckets | None = None, ): """ - Generator from Mimesis. + Initialise a generator using Mimesis. :param function_name: is relative to 'generic', for example 'person.name'. :param value_fn: Function to convert generator output to floats, if needed. 
The values @@ -423,19 +430,25 @@ def __init__( self._fit = buckets.fit_from_values(samples) def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return {} def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {} def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default if self._fit is None else self._fit class MimesisGeneratorTruncated(MimesisGenerator): + """A string generator using Mimesis that must fit within a certain number of characters.""" + def __init__( self, function_name: str, @@ -443,16 +456,20 @@ def __init__( value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): + """Initialise a MimesisGeneratorTruncated.""" self._length = length super().__init__(function_name, value_fn, buckets) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.truncated_string" def name(self) -> str: + """Get the name of the generator.""" return f"{self._name} [truncated to {self._length}]" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "subgen_fn": self._name, "params": {}, @@ -460,6 +477,7 @@ def nominal_kwargs(self) -> dict[str, Any]: } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "subgen_fn": self._name, "params": {}, @@ -467,10 +485,13 @@ def actual_kwargs(self) -> dict[str, Any]: } def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self._generator_function()[: self._length] for _ in range(count)] class MimesisDateTimeGenerator(MimesisGeneratorBase): + """DateTime generator using 
Mimesis.""" + def __init__( self, column: Column, @@ -481,6 +502,8 @@ def __init__( end: int, ) -> None: """ + Initialise a MimesisDateTimeGenerator. + :param column: The column to generate into :param function_name: The name of the mimesis function :param min_year: SQL expression extracting the minimum year @@ -499,6 +522,7 @@ def __init__( def make_singleton( _cls, column: Column, engine: Engine, function_name: str ) -> Sequence[Generator]: + """Make the appropriate generation configuration for this column.""" extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -522,18 +546,21 @@ def make_singleton( ] def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "start": self._start, "end": self._end, } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" return { f"{self._column.name}__start": { "clause": self._min_year, @@ -546,6 +573,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ self._generator_function(start=self._start, end=self._end) for _ in range(count) @@ -553,6 +581,7 @@ def generate_data(self, count: int) -> list[Any]: def get_column_type(column: Column) -> TypeEngine: + """Get the type of the column, generic if possible.""" try: return column.type.as_generic() except NotImplementedError: @@ -560,9 +589,7 @@ def get_column_type(column: Column) -> TypeEngine: class 
MimesisStringGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return strings. - """ + """All Mimesis generators that return strings.""" GENERATOR_NAMES = [ "address.calling_code", @@ -601,6 +628,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -639,13 +667,12 @@ def get_generators( class MimesisFloatGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return floating point numbers. - """ + """All Mimesis generators that return floating point numbers.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -662,13 +689,12 @@ def get_generators( class MimesisDateGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return dates. - """ + """All Mimesis generators that return dates.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -679,13 +705,12 @@ def get_generators( class MimesisDateTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return datetimes. - """ + """All Mimesis generators that return datetimes.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -698,13 +723,12 @@ def get_generators( class MimesisTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return times. 
- """ + """All Mimesis generators that return times.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -715,13 +739,12 @@ def get_generators( class MimesisIntegerGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return integers. - """ + """All Mimesis generators that return integers.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -732,27 +755,33 @@ def get_generators( def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: + """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) class ContinuousDistributionGenerator(Generator): + """Base class for generators producing continuous distributions.""" + expected_buckets: Sequence[numeric] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): + """Initialise a ContinuousDistributionGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" if self.buckets is None: return {} return { @@ -761,6 +790,7 @@ def actual_kwargs(self) -> dict[str, Any]: } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the 
generators need to call.""" clauses = super().select_aggregate_clauses() return { **clauses, @@ -775,12 +805,15 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) class GaussianGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a Gaussian (normal) distribution.""" + expected_buckets = [ 0.0227, 0.0441, @@ -795,9 +828,11 @@ class GaussianGenerator(ContinuousDistributionGenerator): ] def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.normal" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -805,6 +840,8 @@ def generate_data(self, count: int) -> list[Any]: class UniformGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a uniform distribution.""" + expected_buckets = [ 0, 0.06698, @@ -819,9 +856,11 @@ class UniformGenerator(ContinuousDistributionGenerator): ] def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.uniform_ms" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -829,9 +868,7 @@ def generate_data(self, count: int) -> list[Any]: class ContinuousDistributionGeneratorFactory(GeneratorFactory): - """ - All generators that want an average and standard deviation. 
- """ + """All generators that want an average and standard deviation.""" def _get_generators_from_buckets( self, @@ -848,6 +885,7 @@ def _get_generators_from_buckets( def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -865,6 +903,8 @@ def get_generators( class LogNormalGenerator(Generator): + """Generator producing numbers in a log-normal distribution.""" + # TODO: figure out the real buckets here (this was from a random sample in R) expected_buckets = [ 0, @@ -887,6 +927,7 @@ def __init__( logmean: float, logstddev: float, ): + """Initialise a LogNormalGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -895,24 +936,29 @@ def __init__( self.logstddev = logstddev def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.lognormal" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "logmean": self.logmean, "logsd": self.logstddev, } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" clauses = super().select_aggregate_clauses() return { **clauses, @@ -927,15 +973,14 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: 
+ """Get this generator's fit against the real data.""" if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): - """ - All generators that want an average and standard deviation of log data. - """ + """All generators that want an average and standard deviation of log data.""" def _get_generators_from_buckets( self, @@ -968,8 +1013,8 @@ def _get_generators_from_buckets( def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: """ - Get a zipf distribution for a certain number of items distributed - in a certain number of bins. + Get a zipf distribution for a certain number of items. + :param total: The total number of items to be distributed. :param bins: The total number of bins to distribute the items into. :return: A generator of the number of items in each bin, from the @@ -989,6 +1034,8 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None class ChoiceGenerator(Generator): + """Base generator for all generators producing choices of items.""" + STORE_COUNTS = False def __init__( @@ -1000,6 +1047,7 @@ def __init__( sample_count: int | None = None, suppress_count: int = 0, ) -> None: + """Initialise a ChoiceGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -1035,27 +1083,29 @@ def __init__( @abstractmethod def get_estimated_counts(self, counts: list[int]) -> list[int]: - """ - The counts that we would expect if this distribution was the correct one. 
- """ + """Get the counts that we would expect if this distribution was the correct one.""" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } def name(self) -> str: + """Get the name of the generator.""" n = super().name() if self._annotation is None: return n return f"{n} [{self._annotation}]" def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "a": self.values, } def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" qs = super().custom_queries() return { **qs, @@ -1066,17 +1116,23 @@ def custom_queries(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default if self._fit is None else self._fit class ZipfChoiceGenerator(ChoiceGenerator): + """Generator producing items in a Zipf distribution.""" + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return list(zipf_distribution(sum(counts), len(counts))) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.zipf_choice" def generate_data(self, count: int) -> list[float]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] @@ -1084,7 +1140,8 @@ def generate_data(self, count: int) -> list[float]: def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: """ - A generator putting ``total`` items uniformly into ``bins`` bins. + Construct a distribution putting ``total`` items uniformly into ``bins`` bins. 
+ If they don't fit exactly evenly, the earlier bins will have one more item than the later bins so the total is as required. """ @@ -1097,30 +1154,36 @@ def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, N class UniformChoiceGenerator(ChoiceGenerator): - """ - A generator producing values, each roughly as frequently as each other. - """ + """A generator producing values, each roughly as frequently as each other.""" def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return list(uniform_distribution(sum(counts), len(counts))) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.choice" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.choice(self.values) for _ in range(count)] class WeightedChoiceGenerator(ChoiceGenerator): + """Choice generator that matches the source data's frequency.""" + STORE_COUNTS = True def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return counts def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.weighted_choice" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.weighted_choice(self.values) for _ in range(count)] @@ -1143,6 +1206,7 @@ class ValueGatherer: """ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: + """Initialise a ValueGatherer.""" values = [] # All values found counts = [] # The number or each value cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" @@ -1176,9 +1240,7 @@ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: class ChoiceGeneratorFactory(GeneratorFactory): - 
""" - All generators that want an average and standard deviation. - """ + """All generators that want an average and standard deviation.""" SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 @@ -1186,6 +1248,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -1293,32 +1356,38 @@ def get_generators( class ConstantGenerator(Generator): + """Generator that always produces the same value.""" + def __init__(self, value: Any) -> None: + """Initialise the ConstantGenerator.""" super().__init__() self.value = value self.repr = repr(value) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.constant" def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" return {"value": self.repr} def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {"value": self.value} def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self.value for _ in range(count)] class ConstantGeneratorFactory(GeneratorFactory): - """ - Just the null generator - """ + """Just the null generator.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate for these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -1335,9 +1404,7 @@ def get_generators( class MultivariateNormalGenerator(Generator): - """ - Generator of multiple values drawn from a multivariate normal distribution. 
- """ + """Generator of multiple values drawn from a multivariate normal distribution.""" def __init__( self, @@ -1347,6 +1414,7 @@ def __init__( covariates: RowMapping, function_name: str, ) -> None: + """Initialise a MultivariateNormalGenerator.""" self._table = table_name self._columns = column_names self._query = query @@ -1354,14 +1422,17 @@ def __init__( self._function_name = function_name def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen." + self._function_name def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', } def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { @@ -1371,32 +1442,34 @@ def custom_queries(self) -> dict[str, Any]: } def actual_kwargs(self) -> dict[str, Any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {"cov": self._covariates} def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. 
- """ + """Generate 'count' random data points for this column.""" return [ getattr(dist_gen, self._function_name)(self._covariates) for _ in range(count) ] def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default class MultivariateNormalGeneratorFactory(GeneratorFactory): + """Normal distribution generator factory.""" + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "multivariate_normal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" return column.name + " IS NOT NULL" def query_var(self, column: str) -> str: + """Get the SQL expression of the value to query for this column.""" return column def query( @@ -1411,7 +1484,8 @@ def query( sample_count: int | None = None, ) -> str: """ - Gets a query for the basics for multivariate normal/lognormal parameters. + Get a query for the basics for multivariate normal/lognormal parameters. + :param table: The name of the table to be queried. :param columns: The columns in the multivariate distribution. :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. 
@@ -1456,6 +1530,7 @@ def query( def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators for these columns.""" # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1487,20 +1562,23 @@ def get_generators( class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Multivariate lognormal generator factory.""" + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "multivariate_lognormal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" return f"COALESCE(0 < {column.name}, FALSE)" def query_var(self, column: str) -> str: + """Get the expression to query for, for this column.""" return f"LN({column})" def text_list(items: Iterable[str]) -> str: - """ - Concatenate the items with commas and one "and". - """ + """Concatenate the items with commas and one "and".""" item_i = iter(items) try: last_item = next(item_i) @@ -1518,6 +1596,8 @@ def text_list(items: Iterable[str]) -> str: @dataclass class RowPartition: + """A partition where all the rows have the same pattern of NULLs.""" + query: str # list of numeric columns included_numeric: list[Column] @@ -1537,6 +1617,7 @@ class RowPartition: covariates: Sequence[RowMapping] def comment(self) -> str: + """Make an appropriate comment for this partition.""" caveat = "" if self.included_choice: caveat = f" (for each possible value of {text_list(self.included_choice.values())})" @@ -1584,6 +1665,7 @@ def __init__( partition_counts: Iterable[RowMapping] = [], partition_count_comment: str | None = None, ): + """Initialise a NullPartitionedNormalGenerator.""" self._query_name = query_name self._partitions = partitions self._function_name = function_name @@ -1596,14 +1678,17 @@ def __init__( self._name = f"null-partitioned {function_name}" def name(self) -> str: + """Get the name of the 
generator.""" return self._name def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.alternatives" def _nominal_kwargs_with_combinations( self, index: int, partition: RowPartition ) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml`` for a single partition.""" count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1634,6 +1719,7 @@ def _count_query_name(self) -> str: return f"auto__cov__{self._query_name}__counts" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "alternative_configs": [ self._nominal_kwargs_with_combinations(index, self._partitions[index]) @@ -1643,6 +1729,7 @@ def nominal_kwargs(self) -> dict[str, Any]: } def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" partitions = { f"auto__cov__{self._query_name}__alt_{index}": { "comment": partition.comment(), @@ -1690,9 +1777,7 @@ def _actual_kwargs_with_combinations( } def actual_kwargs(self) -> dict[str, Any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "alternative_configs": [ self._actual_kwargs_with_combinations(self._partitions[index]) @@ -1702,31 +1787,29 @@ def actual_kwargs(self) -> dict[str, Any]: } def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. 
- """ + """Generate 'count' random data points for this column.""" kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default def is_numeric(col: Column) -> bool: + """Test if this column stores a numeric value.""" ct = get_column_type(col) return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys def powerset(input: list[T]) -> Iterable[Iterable[T]]: - """Returns a list of all sublists of""" + """Get a list of all sublists of ``input``.""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @dataclass class NullableColumn: - """ - A reference to a nullable column whose nullability is part of a partitioning. - """ + """A reference to a nullable column whose nullability is part of a partitioning.""" column: Column # The bit (power of two) of the number of the partition in the partition sizes list @@ -1734,13 +1817,12 @@ class NullableColumn: class NullPatternPartition: - """ - The definition of a partition (in other words, what makes it not another partition) - """ + """The definition of a partition (in other words, what makes it not another partition).""" def __init__( self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] ): + """Initialise a pattern of nulls which can be queried for.""" self.index = sum(nc.bitmask for nc in partition_nonnulls) nonnull_columns = {nc.column.name for nc in partition_nonnulls} self.included_numeric: list[Column] = [] @@ -1772,6 +1854,8 @@ def __init__( class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Produces null partitioned generators, for complex interdependent data.""" + SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 EMPTY_RESULT = [ @@ -1784,24 +1868,22 @@ class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): ] def 
function_name(self) -> str: + """Get the name of the generator function to call.""" return "grouped_multivariate_normal" def query_predicate(self, column: Column) -> str: - """ - Returns a SQL expression that is true when ``column`` is available for analysis. - """ + """Get a SQL expression that is true when ``column`` is available for analysis.""" if is_numeric(column): # x <> x + 1 ensures that x is not infinity or NaN return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" return f"{column.name} IS NOT NULL" def query_var(self, column: str) -> str: + """Return the expression we are querying for in this column.""" return column def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: - """ - Gets a list of nullable columns together with bitmasks. - """ + """Get a list of nullable columns together with bitmasks.""" out: list[NullableColumn] = [] for col in columns: if col.nullable: @@ -1817,7 +1899,7 @@ def get_partition_count_query( self, ncs: list[NullableColumn], table: str, where: str | None = None ) -> str: """ - Returns a SQL expression returning columns ``count`` and ``index``. + Get a SQL expression returning columns ``count`` and ``index``. Each row returned represents one of the null pattern partitions. ``index`` is the bitmask of all those nullable columns that are not null for @@ -1834,6 +1916,7 @@ def get_partition_count_query( def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get any appropriate generators for these columns.""" if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) @@ -1934,6 +2017,7 @@ def _execute_partition_queries( ) -> bool: """ Execute the query in each partition, filling in the covariates. + :return: True if all the partitions work, False if any of them fail. 
""" found_nonzero = False @@ -1948,21 +2032,32 @@ def _execute_partition_queries( class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): + """ + A generator for numeric and non-numeric columns. + + Any values could be null, the distributions of the nonnull numeric columns + depend on each other and the other non-numeric column values. + """ + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "grouped_multivariate_lognormal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression testing if the value in this column should be used.""" if is_numeric(column): # x <> x + 1 ensures that x is not infinity or NaN return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" return f"{column.name} IS NOT NULL" def query_var(self, column: str) -> str: + """Get the variable or expression we are querying for this column.""" return f"LN({column})" @lru_cache(1) def everything_factory() -> GeneratorFactory: + """Get a factory that encapsulates all the other factories.""" return MultiGeneratorFactory( [ MimesisStringGeneratorFactory(), diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 580fa3af..d35806d1 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1,3 +1,4 @@ +"""Interactive configuration commands.""" import cmd import csv import functools @@ -39,11 +40,13 @@ def or_default(v: T | None, d: T) -> T: - """Returns v if it isn't None, otherwise d.""" + """Return v if it isn't None, otherwise d.""" return d if v is None else v class TableType(Enum): + """Types of table to be configured.""" + GENERATE = "generate" IGNORE = "ignore" VOCABULARY = "vocabulary" @@ -70,38 +73,49 @@ class TableType(Enum): @dataclass class TableEntry: + """Base class for table entries for interactive commands.""" + name: str # name of the table class AskSaveCmd(cmd.Cmd): + """Interactive shell for whether to save and quit.""" + intro = "Do 
you want to save this configuration?" prompt = "(yes/no/cancel) " file = None def __init__(self) -> None: + """Initialise a save command.""" super().__init__() self.result = "" def do_yes(self, _arg: str) -> bool: + """Save the new config.yaml.""" self.result = "yes" return True def do_no(self, _arg: str) -> bool: + """Exit without saving.""" self.result = "no" return True def do_cancel(self, _arg: str) -> bool: + """Do not exit.""" self.result = "cancel" return True def fk_column_name(fk: ForeignKey) -> str: + """Display name for a foreign key.""" if fk_refers_to_ignored_table(fk): return f"{fk.target_fullname} (ignored)" return str(fk.target_fullname) class DbCmd(ABC, cmd.Cmd): + """Base class for interactive configuration commands.""" + INFO_NO_MORE_TABLES = "There are no more tables" ERROR_ALREADY_AT_START = "Error: Already at the start" ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" @@ -110,6 +124,13 @@ class DbCmd(ABC, cmd.Cmd): @abstractmethod def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None: + """ + Make a table entry suitable for this interactive command. + + :param name: The name of the table to make an entry for. + :param table_config: The part of the ``config.yaml`` referring to this table. + :return: The table entry or None if this table should not be interacted with. + """ ... 
def __init__( @@ -119,6 +140,7 @@ def __init__( metadata: MetaData, config: MutableMapping[str, Any], ): + """Initialise a DbCmd.""" super().__init__() self.config: MutableMapping[str, Any] = config self.metadata = metadata @@ -138,9 +160,11 @@ def __init__( @property def sync_engine(self) -> Engine: + """Get the synchronous version of the engine.""" return get_sync_engine(self.engine) def __enter__(self) -> Self: + """Enter a ``with`` statement.""" return self def __exit__( @@ -149,12 +173,20 @@ def __exit__( _exc_val: Optional[BaseException], _exc_tb: Optional[TracebackType], ) -> None: + """Dispose of this object.""" self.engine.dispose() def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Print text, formatted with positional and keyword arguments.""" print(text.format(*args, **kwargs)) def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: + """ + Print a table. + + :param headings: List of headings for the table. + :param rows: List of rows of values. + """ output = PrettyTable() output.field_names = headings for row in rows: @@ -162,6 +194,11 @@ def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: print(output) def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: + """ + Print a table. + + :param columns: Dict of column names to the values in the column. + """ output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): @@ -169,11 +206,13 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: print(output) def print_results(self, result: sqlalchemy.CursorResult) -> None: + """Print the rows resulting from a database query.""" self.print_table(list(result.keys()), [list(row) for row in result.all()]) def ask_save(self) -> str: """ Ask the user if they want to save. + :return: ``yes``, ``no`` or ``cancel``. 
""" ask = AskSaveCmd() @@ -182,11 +221,13 @@ def ask_save(self) -> str: @abstractmethod def set_prompt(self) -> None: + """Set the prompt according to the current state.""" ... def set_table_index(self, index: int) -> bool: """ Move to a different table. + :param index: Index of the table to move to. :return: True if there is a table with such an index to move to. """ @@ -198,7 +239,9 @@ def set_table_index(self, index: int) -> bool: def next_table(self, report: str = "No more tables") -> bool: """ - Move to the next table + Move to the next table. + + :param report: The text to print if there is no next table. :return: True if there is another table to move to. """ if not self.set_table_index(self.table_index + 1): @@ -211,12 +254,15 @@ def table_name(self) -> str: return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: + """Get the metadata of the current table.""" return self.metadata.tables[self.table_name()] def get_column_names(self) -> list[str]: + """Get the names of the current columns.""" return [col.name for col in self.table_metadata().columns] def report_columns(self) -> None: + """Print information about the current columns.""" self.print_table( ["name", "type", "primary", "nullable", "foreign key"], [ @@ -232,6 +278,7 @@ def report_columns(self) -> None: ) def get_table_config(self, table_name: str) -> dict[str, Any]: + """Get the configuration of the named table.""" ts = self.config.get("tables", None) if type(ts) is not dict: return {} @@ -239,6 +286,7 @@ def get_table_config(self, table_name: str) -> dict[str, Any]: return t if type(t) is dict else {} def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: + """Set the configuration of the named table.""" ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} @@ -246,6 +294,7 @@ def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: ts[table_name] = config def 
_remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: + """Remove all source stats with the given prefix from the configuration.""" src_stats = self.config.get("src-stats", []) new_src_stats = [] for stat in src_stats: @@ -255,6 +304,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: return new_src_stats def get_nonnull_columns(self, table_name: str) -> list[str]: + """Get the names of the nullable columns in the named table.""" metadata_table = self.metadata.tables[table_name] return [ str(name) @@ -263,6 +313,7 @@ def get_nonnull_columns(self, table_name: str) -> list[str]: ] def find_entry_index_by_table_name(self, table_name: str) -> int | None: + """Get the index of the table entry of the named table.""" return next( ( i @@ -273,13 +324,14 @@ def find_entry_index_by_table_name(self, table_name: str) -> int | None: ) def find_entry_by_table_name(self, table_name: str) -> TableEntry | None: + """Get the table entry of the named table.""" for e in self._table_entries: if e.name == table_name: return e return None def do_counts(self, _arg: str) -> None: - "Report the column names with the counts of nulls in them" + """Report the column names with the counts of nulls in them.""" if len(self._table_entries) <= self.table_index: return table_name = self.table_name() @@ -309,7 +361,7 @@ def do_counts(self, _arg: str) -> None: ) def do_select(self, arg: str) -> None: - "Run a select query over the database and show the first 50 results" + """Run a select query over the database and show the first 50 results.""" MAX_SELECT_ROWS = 50 with self.sync_engine.connect() as connection: try: @@ -327,6 +379,8 @@ def do_select(self, arg: str) -> None: def do_peek(self, arg: str) -> None: """ + View some data from the current table. + Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table. Use 'peek' to see a sample of the current column(s). Rows that are enitrely null are suppressed. 
@@ -358,6 +412,7 @@ def do_peek(self, arg: str) -> None: def complete_peek( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Completions for the ``peek`` command.""" if len(self._table_entries) <= self.table_index: return [] return [ @@ -367,11 +422,15 @@ def complete_peek( @dataclass class TableCmdTableEntry(TableEntry): + """Table entry for the table command shell.""" + old_type: TableType new_type: TableType class TableCmd(DbCmd): + """Command shell allowing the user to set the type of each table.""" + intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n" doc_leader = """Use the commands 'ignore', 'vocabulary', 'private', 'empty' or 'generate' to set the table's type. Use 'next' or @@ -395,6 +454,13 @@ class TableCmd(DbCmd): NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None: + """ + Make a table entry for the named table. + + :param name: The name of the table. + :param table: The part of ``config.yaml`` corresponding to this table. + :return: The newly-constructed table entry. 
+ """ if table.get("ignore", False): return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE) if table.get("vocabulary_table", False): @@ -412,20 +478,24 @@ def __init__( metadata: MetaData, config: MutableMapping[str, Any], ) -> None: + """Initialise a TableCmd.""" super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @property def table_entries(self) -> list[TableCmdTableEntry]: + """Get the list of table entries.""" return cast(list[TableCmdTableEntry], self._table_entries) def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + """Get the table entry of the table with the given name.""" entry = super().find_entry_by_table_name(table_name) if entry is None: return None return cast(TableCmdTableEntry, entry) def set_prompt(self) -> None: + """Set the prompt according to the current table and its type.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) @@ -433,11 +503,13 @@ def set_prompt(self) -> None: self.prompt = "(table) " def set_type(self, t_type: TableType) -> None: + """Set the type of the current table.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type def _copy_entries(self) -> None: + """Alter the configuration to match the new table entries.""" for entry in self.table_entries: if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) @@ -470,6 +542,7 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table) def _get_referenced_tables(self, from_table_name: str) -> set[str]: + """Get all the tables referenced by this table's foreign keys.""" from_meta = self.metadata.tables[from_table_name] return { fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys @@ -520,7 +593,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: return warnings def 
do_quit(self, _arg: str) -> bool: - "Check the updates, save them if desired and quit the configurer." + """Check the updates, save them if desired and quit the configurer.""" count = 0 for entry in self.table_entries: if entry.old_type != entry.new_type: @@ -552,7 +625,7 @@ def do_quit(self, _arg: str) -> bool: return False def do_tables(self, _arg: str) -> None: - "list the tables with their types" + """List the tables with their types.""" for entry in self.table_entries: old = entry.old_type new = entry.new_type @@ -560,7 +633,7 @@ def do_tables(self, _arg: str) -> None: self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) def do_next(self, arg: str) -> None: - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" + """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" if arg: # Find the index of the table called _arg, if any index = self.find_entry_index_by_table_name(arg) @@ -574,52 +647,54 @@ def do_next(self, arg: str) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Get the completions for table names.""" return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] def do_previous(self, _arg: str) -> None: - "Go to the previous table" + """Go to the previous table.""" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def do_ignore(self, _arg: str) -> None: - "Set the current table as ignored, and go to the next table" + """Set the current table as ignored, and go to the next table.""" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() def do_vocabulary(self, _arg: str) -> None: - "Set the current table as a vocabulary table, and go to the next table" + """Set the current table as a vocabulary table, and go to the next table.""" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", 
self.table_name()) self.next_table() def do_private(self, _arg: str) -> None: - "Set the current table as a primary private table (such as the table of patients)" + """Set the current table as a primary private table (such as the table of patients).""" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() def do_generate(self, _arg: str) -> None: - "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" + """Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table.""" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() def do_empty(self, _arg: str) -> None: - "Set the current table as empty; no generators will be run for it" + """Set the current table as empty; no generators will be run for it.""" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() def do_columns(self, _arg: str) -> None: - "Report the column names and metadata" + """Report the column names and metadata.""" self.report_columns() def do_data(self, arg: str) -> None: """ Report some data. + 'data' = report a random ten lines, 'data 20' = report a random 20 lines, 'data 20 ColumnName' = report a random twenty entries from ColumnName, @@ -660,6 +735,7 @@ def do_data(self, arg: str) -> None: def complete_data( self, text: str, line: str, begidx: int, _endidx: int ) -> list[str]: + """Get completions for arguments to ``data``.""" previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] @@ -667,6 +743,13 @@ def complete_data( return [k for k in table_metadata.columns.keys() if k.startswith(text)] def print_column_data(self, column: str, count: int, min_length: int) -> None: + """ + Print a sample of data from a certain column of the current table. + + :param column: The name of the column to report on. 
+ :param count: The number of rows to sample. + :param min_length: The minimum length of text to choose from (0 for any text). + """ where = f"WHERE {column} IS NOT NULL" if 0 < min_length: where = "WHERE LENGTH({column}) >= {len}".format( @@ -687,6 +770,11 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: self.columnize([str(x[0]) for x in result.all()]) def print_row_data(self, count: int) -> None: + """ + Print a sample of rows from the current table. + + :param count: The number of rows to report. + """ with self.sync_engine.connect() as connection: result = connection.execute( text( @@ -705,6 +793,7 @@ def print_row_data(self, count: int) -> None: def update_config_tables( src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping ) -> Mapping[str, Any]: + """Ask the user to specify what should happen to each table.""" with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() return tc.config @@ -712,6 +801,8 @@ def update_config_tables( @dataclass class MissingnessType: + """The functions required for applying missingness.""" + SAMPLED = "column_presence.sampled" SAMPLED_QUERY = ( "SELECT COUNT(*) AS row_count, {result_names} FROM " @@ -725,6 +816,14 @@ class MissingnessType: @classmethod def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str: + """ + Construct a query to make a sampling of the named rows of the table. + + :param table: The name of the table to sample. + :param count: The number of samples to get. + :param column_names: The columns to fetch. + :return: The SQL query to do the sampling. 
+ """ result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) column_is_nulls = ", ".join( ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] @@ -739,11 +838,19 @@ def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> s @dataclass class MissingnessCmdTableEntry(TableEntry): + """Table entry for the missingness command shell.""" + old_type: MissingnessType new_type: MissingnessType | None class MissingnessCmd(DbCmd): + """ + Interactive shell for the user to set missingness. + + Can only be used for Missingness Completely At Random. + """ + intro = "Interactive missingness configuration. Type ? for help.\n" doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for the current table. Use commands 'next' and 'previous' to change the @@ -774,6 +881,13 @@ def find_missingness_query( def make_table_entry( self, name: str, table: Mapping ) -> MissingnessCmdTableEntry | None: + """ + Make a table entry for a particular table. + + :param name: The name of the table to make an entry for. + :param table: The part of ``config.yaml`` relating to this table. + :return: The newly-constructed table entry. + """ if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -822,6 +936,7 @@ def __init__( ): """ Initialise a MissingnessCmd. + :param src_dsn: connection string for the source database. :param src_schema: schema name for the source database. :param metadata: SQLAlchemy metadata for the source database. 
@@ -832,20 +947,20 @@ def __init__( @property def table_entries(self) -> list[MissingnessCmdTableEntry]: + """Get the table entries list.""" return cast(list[MissingnessCmdTableEntry], self._table_entries) def find_entry_by_table_name( self, table_name: str ) -> MissingnessCmdTableEntry | None: + """Find the table entry given the table name.""" entry = super().find_entry_by_table_name(table_name) if entry is None: return None return cast(MissingnessCmdTableEntry, entry) def set_prompt(self) -> None: - """ - Sets the prompt according to the current table and missingness. - """ + """Set the prompt according to the current table and missingness.""" if self.table_index < len(self.table_entries): entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] nt = entry.new_type @@ -857,11 +972,13 @@ def set_prompt(self) -> None: self.prompt = "(missingness) " def set_type(self, t_type: MissingnessType) -> None: + """Set the missingness of the current table.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type def _copy_entries(self) -> None: + """Set the new missingness into the configuration.""" src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: table = self.get_table_config(entry.name) @@ -892,7 +1009,7 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table) def do_quit(self, _arg: str) -> bool: - "Check the updates, save them if desired and quit the configurer." 
+ """Check the updates, save them if desired and quit the configurer.""" count = 0 for entry in self.table_entries: if entry.old_type != entry.new_type: @@ -927,7 +1044,7 @@ def do_quit(self, _arg: str) -> bool: return False def do_tables(self, _arg: str) -> None: - "list the tables with their types" + """List the tables with their types.""" for entry in self.table_entries: old = "-" if entry.old_type is None else entry.old_type.name new = "-" if entry.new_type is None else entry.new_type.name @@ -935,7 +1052,11 @@ def do_tables(self, _arg: str) -> None: self.print("{0} {1}", entry.name, desc) def do_next(self, arg: str) -> None: - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" + """ + Go to the next table, or a specified table. + + 'next' = go to the next table, 'next tablename' = go to table 'tablename' + """ if arg: # Find the index of the table called _arg, if any index = next( @@ -952,19 +1073,18 @@ def do_next(self, arg: str) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Get completions for tables and columns.""" return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] def do_previous(self, _arg: str) -> None: - "Go to the previous table" + """Go to the previous table.""" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def _set_type(self, name: str, query: str, comment: str) -> None: - """ - Set the current table entry's query. - """ + """Set the current table entry's query.""" if len(self.table_entries) <= self.table_index: return entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] @@ -976,9 +1096,7 @@ def _set_type(self, name: str, query: str, comment: str) -> None: ) def _set_none(self) -> None: - """ - Sets the current table to have no missingness applied. 
- """ + """Set the current table to have no missingness applied.""" if len(self.table_entries) <= self.table_index: return self.table_entries[self.table_index].new_type = None @@ -986,9 +1104,10 @@ def _set_none(self) -> None: def do_sampled(self, arg: str) -> None: """ Set the current table missingness as 'sampled', and go to the next table. - "sampled 3000" means sample 3000 rows at random and choose the missingness - to be the same as one of those 3000 at random. - "sampled" means the same, but with a default number of rows sampled (1000). + + 'sampled 3000' means sample 3000 rows at random and choose the + missingness to be the same as one of those 3000 at random. + 'sampled' means the same, but with a default number of rows sampled (1000). """ if len(self.table_entries) <= self.table_index: self.print("Error! not on a table") @@ -1017,7 +1136,7 @@ def do_sampled(self, arg: str) -> None: self.next_table() def do_none(self, _arg: str) -> None: - "Set the current table to have no missingness, and go to the next table" + """Set the current table to have no missingness, and go to the next table.""" self._set_none() self.print("Table {} set to have no missingness", self.table_name()) self.next_table() @@ -1029,6 +1148,16 @@ def update_missingness( metadata: MetaData, config: MutableMapping[str, Any], ) -> Mapping[str, Any]: + """ + Ask the user to update the missingness information in ``config.yaml``. + + :param src_dsn: The connection string for the source database. + :param src_schema: The name of the source database schema (or None + for the default). + :param metadata: The SQLAlchemy metadata object from ``orm.yaml``. + :param config: The starting configuration, + :return: The updated configuration. + """ with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() return mc.config @@ -1036,9 +1165,7 @@ def update_missingness( @dataclass class GeneratorInfo: - """ - A generator and the columns it assigns to. 
- """ + """A generator and the columns it assigns to.""" columns: list[str] gen: Generator | None @@ -1048,6 +1175,7 @@ class GeneratorInfo: class GeneratorCmdTableEntry(TableEntry): """ List of generators set for a table. + Includes the original setting and the currently configured generators. """ @@ -1057,9 +1185,7 @@ class GeneratorCmdTableEntry(TableEntry): class GeneratorCmd(DbCmd): - """ - Interactive command shell for setting generators. - """ + """Interactive command shell for setting generators.""" intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the @@ -1095,6 +1221,13 @@ class GeneratorCmd(DbCmd): def make_table_entry( self, table_name: str, table: Mapping ) -> GeneratorCmdTableEntry | None: + """ + Make a table entry. + + :param table_name: The name of the table. + :param table: The portion of the ``config.yaml`` file describing this table. + :return: The newly constructed table entry, or None if this table is to be ignored. + """ if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -1169,7 +1302,8 @@ def __init__( config: MutableMapping[str, Any], ) -> None: """ - Initialise a GeneratorCmd + Initialise a ``GeneratorCmd``. + :param src_dsn: connection address for source database :param src_schema: database schema name :param metadata: SQLAlchemy metadata for the source database @@ -1182,11 +1316,18 @@ def __init__( @property def table_entries(self) -> list[GeneratorCmdTableEntry]: + """Get the talbe entries, cast to ``GeneratorCmdTableEntry``.""" return cast(list[GeneratorCmdTableEntry], self._table_entries) def find_entry_by_table_name( self, table_name: str ) -> GeneratorCmdTableEntry | None: + """ + Find the table entry by name. + + :param table_name: The name of the table to find. + :return: The table entry, or None if no such table name exists. 
+ """ entry = super().find_entry_by_table_name(table_name) if entry is None: return None @@ -1194,7 +1335,8 @@ def find_entry_by_table_name( def set_table_index(self, index: int) -> bool: """ - Moves to a new table. + Move to a new table. + :param index: table index to move to. """ ret = super().set_table_index(index) @@ -1206,6 +1348,7 @@ def set_table_index(self, index: int) -> bool: def previous_table(self) -> bool: """ Move to the table before the current one. + :return: True if there is a previous table to go to. """ ret = self.set_table_index(self.table_index - 1) @@ -1223,17 +1366,13 @@ def previous_table(self) -> bool: return ret def get_table(self) -> GeneratorCmdTableEntry | None: - """ - Get the current table entry. - """ + """Get the current table entry.""" if self.table_index < len(self.table_entries): return self.table_entries[self.table_index] return None def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: - """ - Gets a pair; the table name then the generator information. - """ + """Get a pair; the table name then the generator information.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] if self.generator_index < len(entry.new_generators): @@ -1242,25 +1381,19 @@ def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: return (None, None) def get_column_names(self) -> list[str]: - """ - Gets the (unqualified) names for all the current columns. - """ + """Get the (unqualified) names for all the current columns.""" (_, generator_info) = self.get_table_and_generator() return generator_info.columns if generator_info else [] def column_metadata(self) -> list[Column]: - """ - Gets the metadata for all the current columns. 
- """ + """Get the metadata for all the current columns.""" table = self.table_metadata() if table is None: return [] return [table.columns[name] for name in self.get_column_names()] def set_prompt(self) -> None: - """ - Set the prompt according to the current table, column and generator. - """ + """Set the prompt according to the current table, column and generator.""" (table_name, gen_info) = self.get_table_and_generator() if table_name is None: self.prompt = "(generators) " @@ -1277,14 +1410,15 @@ def set_prompt(self) -> None: def _remove_auto_src_stats(self) -> list[dict[str, Any]]: """ - Remove all automatic source stats (which we assume is - every source stats query whose name begins with ``auto__`)""" + Remove all automatic source stats. + + We assume every source stats query whose name begins with ``auto__` + :return: The new ``src_stats`` configuration. + """ return self._remove_prefix_src_stats("auto__") def _copy_entries(self) -> None: - """ - Set generator and query information in the configuration. - """ + """Set generator and query information in the configuration.""" src_stats = self._remove_auto_src_stats() for entry in self.table_entries: rgs = [] @@ -1344,7 +1478,7 @@ def _find_old_generator( return None def do_quit(self, arg: str) -> bool: - "Check the updates, save them if desired and quit the configurer." 
+ """Check the updates, save them if desired and quit the configurer.""" count = 0 for entry in self.table_entries: header_shown = False @@ -1377,7 +1511,7 @@ def do_quit(self, arg: str) -> bool: return False def do_tables(self, arg: str) -> None: - "list the tables" + """List the tables.""" for t_entry in self.table_entries: entry = cast(GeneratorCmdTableEntry, t_entry) gen_count = len(entry.new_generators) @@ -1385,7 +1519,7 @@ def do_tables(self, arg: str) -> None: self.print("{0} ({1})", entry.name, how_many) def do_list(self, arg: str) -> None: - "list the generators in the current table" + """List the generators in the current table.""" if len(self.table_entries) <= self.table_index: self.print("Error: no table {0}", self.table_index) return @@ -1408,11 +1542,11 @@ def do_list(self, arg: str) -> None: self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) def do_columns(self, _arg: str) -> None: - "Report the column names and metadata" + """Report the column names and metadata.""" self.report_columns() def do_info(self, _arg: str) -> None: - "Show information about the current column" + """Show information about the current column.""" for cm in self.column_metadata(): self.print( "Column {0} in table {1} has type {2} ({3}).", @@ -1436,12 +1570,21 @@ def do_info(self, _arg: str) -> None: ) def _get_table_index(self, table_name: str) -> int | None: + """Get the index of the named table in the table entries list.""" for n, entry in enumerate(self.table_entries): if entry.name == table_name: return n return None def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + """ + Get the index number of a column within the list of generators in this table. + + :param table_index: The index of the table in which to search. + :param column_name: The name of the column to search for. + :return: The index in the ``new_generators`` attribute of the table entry + containing the specified column, or None if this does not exist. 
+ """ entry = self.table_entries[table_index] for n, gen in enumerate(entry.new_generators): if column_name in gen.columns: @@ -1449,6 +1592,11 @@ def _get_generator_index(self, table_index: int, column_name: str) -> int | None return None def go_to(self, target: str) -> bool: + """ + Go to a particular column. + + :return: True on success. + """ parts = target.split(".", 1) table_index = self._get_table_index(parts[0]) if table_index is None: @@ -1474,10 +1622,11 @@ def go_to(self, target: str) -> bool: def do_next(self, arg: str) -> None: """ - Go to the next generator. - Or go to a named table: 'next tablename'. - Or go to a column: 'next tablename.columnname'. - Or go to a column within this table: 'next columnname'. + Go to the next generator. or a specified generator. + + Go to a named table: 'next tablename', + go to a column: 'next tablename.columnname', + or go to a column within this table: 'next columnname'. """ if arg: self.go_to(arg) @@ -1485,13 +1634,15 @@ def do_next(self, arg: str) -> None: self._go_next() def do_n(self, arg: str) -> None: - """Synonym for next""" + """Go to the next generator, or a specified generator.""" self.do_next(arg) def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: + """Complete the ``n`` command's arguments.""" return self.complete_next(text, line, begidx, endidx) def _go_next(self) -> None: + """Go to the next column.""" table = self.get_table() if table is None: self.print("No more tables") @@ -1506,6 +1657,7 @@ def _go_next(self) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Completions for the arguments of the ``next`` command.""" parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1540,7 +1692,7 @@ def complete_next( return table_names + column_names def do_previous(self, _arg: str) -> None: - """Go to the previous generator""" + """Go to the previous generator.""" if self.generator_index == 0: 
self.previous_table() else: @@ -1548,23 +1700,18 @@ def do_previous(self, _arg: str) -> None: self.set_prompt() def do_b(self, arg: str) -> None: - """Synonym for previous""" + """Synonym for previous.""" self.do_previous(arg) def _generators_valid(self) -> bool: - """ - Return True if the self.generators property is still correct for the - table and columns currently being examined. - """ + """Test if ``self.generators`` is still correct for the current columns.""" return self.generators_valid_columns == ( self.table_index, self.get_column_names(), ) def _get_generator_proposals(self) -> list[Generator]: - """ - Get a list of acceptable generators, sorted by decreasing fit to the actual data. - """ + """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" if not self._generators_valid(): self.generators = None if self.generators is None: @@ -1579,9 +1726,7 @@ def _get_generator_proposals(self) -> list[Generator]: return self.generators def _print_privacy(self) -> None: - """ - Print the privacy status of the current table. - """ + """Print the privacy status of the current table.""" table = self.table_metadata() if table is None: return @@ -1630,6 +1775,10 @@ def do_c(self, arg: str) -> None: def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: """ Print the values queried from the database for this generator. + + :param table_name: The name of the table the generator applies to. + :param n: A number to print at the start of the output. + :param gen: The generator to report. """ if not gen.select_aggregate_clauses() and not gen.custom_queries(): self.print( @@ -1649,6 +1798,8 @@ def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None def _print_custom_queries(self, gen: Generator) -> None: """ Print all the custom queries and all the values they get in this case. + + :param gen: The generator to print the custom queries for. 
""" cqs = gen.custom_queries() if not cqs: @@ -1705,7 +1856,12 @@ def _get_aggregate_query( def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: """ - Prints the select aggregate query and all the values it gets in this case. + Print the select aggregate query and all the values it gets in this case. + + This is not the entire query that will be executed, but only the part of it + that is required by a certain generator. + :param table_name: The table name. + :param gen: The generator to limit the aggregate query to. """ sacs = gen.select_aggregate_clauses() if not sacs: @@ -1793,10 +1949,7 @@ def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: return None def do_set(self, arg: str) -> None: - """ - Set one of the proposals as a generator. - :param arg: A single integer (as a string). - """ + """Set one of the proposals as a generator.""" if arg.isdigit() and not self._generators_valid(): self.print("Please run 'propose' before 'set '") return @@ -1847,7 +2000,6 @@ def do_merge(self, arg: str) -> None: Add this column(s) to the specified column(s). After this, one generator will cover them all. - :param arg: space separated list of column names to merge. """ cols = arg.split() if not cols: @@ -1980,6 +2132,7 @@ def update_config_generators( ) -> Mapping[str, Any]: """ Update configuration with the specification from a CSV file. 
+ The specification is a headerless CSV file with columns: Table name, Column name (or space-separated list of column names), Generator name required, Second choice generator name, Third choice generator diff --git a/datafaker/main.py b/datafaker/main.py index 454cf446..42f6bdea 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -70,9 +70,22 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: - with open(orm_file_name) as orm_fh: +def load_metadata_config( + orm_file_name: str, config: dict | None = None +) -> dict[str, Any]: + """ + Load the ``orm.yaml`` file, returning a dict representation. + + :param orm_file_name: The name of the file to load. + :param config: The ``config.yaml`` file object. Ignored tables will be + excluded from the output. + :return: A dict representing the ``orm.yaml`` file, with the tables + the ``config`` says to ignore removed. + """ + with open(orm_file_name, encoding="utf-8") as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) + if not isinstance(meta_dict, dict): + return {} tables_dict = meta_dict.get("tables", {}) if config is not None and "tables" in config: # Remove ignored tables @@ -84,7 +97,8 @@ def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: """ - Load metadata from ``orm.yaml`` + Load metadata from ``orm.yaml``. + :param orm_file_name: ``orm.yaml`` or alternative name to load metadata from. :param config: Used to exclude tables that are marked as ``ignore: true``. :return: SQLAlchemy MetaData object representing the database described by the loaded file. 
@@ -94,9 +108,7 @@ def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any: - """ - Load metadata excluding any foreign keys pointing to ignored tables. - """ + """Load metadata excluding any foreign keys pointing to ignored tables.""" meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, config) @@ -105,6 +117,7 @@ def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> def main( verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") ) -> None: + """Set the global parameters.""" conf_logger(verbose) @@ -327,7 +340,10 @@ def make_stats( def make_tables( config_file: Optional[str] = Option( None, - help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables", + help=( + "The configuration file, used if you want" + " an orm.yaml lacking data for the ignored tables" + ), ), orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"), force: bool = Option( @@ -361,9 +377,7 @@ def configure_tables( ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: - """ - Interactively set tables to ignored, vocabulary or primary private. - """ + """Interactively set tables to ignored, vocabulary or primary private.""" logger.debug("Configuring tables in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -393,9 +407,7 @@ def configure_missing( ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: - """ - Interactively set the missingness of the generated data. 
- """ + """Interactively set the missingness of the generated data.""" logger.debug("Configuring missingness in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -405,7 +417,7 @@ def configure_missing( config_any = yaml.load( config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) - if type(config_any) is dict: + if isinstance(config_any, dict): config = config_any metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) @@ -428,9 +440,7 @@ def configure_generators( help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", ), ) -> None: - """ - Interactively set generators for column data. - """ + """Interactively set generators for column data.""" logger.debug("Configuring generators in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -468,7 +478,7 @@ def dump_data( schema_name = settings.dst_schema config = read_config_file(config_file) if config_file is not None else {} metadata = load_metadata_for_output(orm_file, config) - if output == None: + if output is None: if isinstance(sys.stdout, io.TextIOBase): dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return @@ -555,6 +565,8 @@ def remove_tables( class TableType(str, Enum): + """Types of tables for the ``list-tables`` command.""" + ALL = "all" VOCAB = "vocab" GENERATED = "generated" diff --git a/datafaker/make.py b/datafaker/make.py index e9db8636..096ee8bf 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -75,7 +75,7 @@ class RowGeneratorInfo: @dataclass class ColumnChoice: - """Chooses columns based on a random number in [0,1)""" + """Choose columns based on a random number in [0,1).""" function_name: str argument_values: list[str] @@ -84,6 +84,14 @@ class ColumnChoice: def make_column_choices( table_config: Mapping[str, Any], ) -> list[ColumnChoice]: + """ + Convert 
``missingness_generators`` from ``config.yaml`` into functions to call. + + :param table_config: The ``tables`` part of ``config.yaml``. + :return: A list of ``ColumnChoice`` objects; that is, descriptions of + functions and their arguments to call to reveal a list of columns that + should have values generated for them. + """ return [ ColumnChoice( function_name=mg["name"], @@ -97,9 +105,9 @@ def make_column_choices( @dataclass class _PrimaryConstraint: """ - Describes a Uniqueness constraint for when multiple - columns in a table comprise the primary key. Not a - real constraint, but enough to write df.py. + Describes a Uniqueness constraint for a multi-column primary key. + + Not a real constraint, but enough to write df.py. """ columns: list[Column] @@ -252,8 +260,10 @@ def _get_default_generator(column: Column) -> RowGeneratorInfo: def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a generator and maybe arguments - that limit its range to the permitted scale. + Get the default generator name and arguments. + + :param column: The column to get the generator for. + :return: The name of a generator and its arguments. """ column_type = column.type scale = getattr(column_type, "scale", None) @@ -270,8 +280,10 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: def _string_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a string generator and maybe arguments - that limit its length. + Get the name of the default string generator for a column. + + :param column: The column to get the generator for. + :return: The name of the generator and its arguments. """ column_size: Optional[int] = getattr(column.type, "length", None) if column_size is None: @@ -281,7 +293,11 @@ def _string_generator(column: Column) -> tuple[str, dict[str, str]]: def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of an integer generator. 
+ Get the name of the default integer generator. + + :param column: The column to get the generator for. + :return: A pair consisting of the name of a generator and its + arguments. """ if not column.primary_key: return ("generic.numeric.integer_number", {}) @@ -302,6 +318,8 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: @dataclass class GeneratorInfo: + """Description of a generator.""" + # Name or function to generate random objects of this type (not using summary data) generator: str | Callable[[Column], tuple[str, dict[str, str]]] # SQL query that gets the data to supply as arguments to the generator @@ -321,8 +339,9 @@ def get_result_mappings( info: GeneratorInfo, results: CursorResult ) -> dict[str, Any] | None: """ - Gets a mapping from the results of a database query as a Python - dictionary converted according to the GeneratorInfo provided. + Get a mapping from the results of a database query. + + :return: A Python dictionary converted according to the GeneratorInfo provided. """ kw: dict[str, Any] = {} mapping = results.mappings().first() @@ -379,7 +398,7 @@ def get_result_mappings( def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: """ - Gets a generator from a column type. + Get a generator from a column type. Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator @@ -400,7 +419,7 @@ def _get_generator_for_column( column_t: type, ) -> str | Callable[[Column], tuple[str, dict[str, str]]] | None: """ - Gets a generator from a column type. + Get a generator from a column type. 
Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator @@ -412,8 +431,9 @@ def _get_generator_for_column( def _get_generator_and_arguments(column: Column) -> tuple[str | None, dict[str, str]]: """ - Gets the generator and its arguments from the column type, returning - a tuple of a string representing the generator callable and a dict of + Get the generator and its arguments from the column type. + + :return: A tuple of a string representing the generator callable and a dict of keyword arguments to supply to it. """ generator_function = _get_generator_for_column(type(column.type)) @@ -540,10 +560,7 @@ def make_vocabulary_tables( compress: bool, table_names: set[str] | None = None, ) -> None: - """ - Extracts the data from the source database for each - vocabulary table. - """ + """Extract the data from the source database for each vocabulary table.""" settings = get_settings() src_dsn: str = settings.src_dsn or "" assert src_dsn != "", "Missing SRC_DSN setting." @@ -660,9 +677,7 @@ def generate_df_content(template_context: Mapping[str, Any]) -> str: def _get_generator_for_existing_vocabulary_table( table: Table, ) -> VocabularyTableGeneratorInfo: - """ - Turns an existing vocabulary YAML file into a VocabularyTableGeneratorInfo. 
- """ + """Turn an existing vocabulary YAML file into a VocabularyTableGeneratorInfo.""" return VocabularyTableGeneratorInfo( dictionary_entry=table.name, variable_name=f"{table.name.lower()}_vocab", @@ -676,9 +691,7 @@ def _generate_vocabulary_table( overwrite_files: bool = False, compress: bool = False, ) -> None: - """ - Pulls data out of the source database to make a vocabulary YAML file - """ + """Pull data out of the source database to make a vocabulary YAML file.""" yaml_file_name: str = table.fullname + ".yaml" if compress: yaml_file_name += ".gz" @@ -692,9 +705,7 @@ def _generate_vocabulary_table( def make_tables_file( db_dsn: str, schema_name: Optional[str], config: Mapping[str, Any] ) -> str: - """ - Construct the YAML file representing the schema. - """ + """Construct the YAML file representing the schema.""" tables_config = config.get("tables", {}) engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) @@ -726,6 +737,8 @@ def reflect_if(table_name: str, _: Any) -> bool: class DbConnection: + """A connection to a database.""" + def __init__(self, engine: MaybeAsyncEngine) -> None: """ Initialise an unopened database connection. 
@@ -736,6 +749,7 @@ def __init__(self, engine: MaybeAsyncEngine) -> None: self._connection: Connection | AsyncConnection async def __aenter__(self) -> Self: + """Enter the ``with`` section, opening a connection.""" if isinstance(self._engine, AsyncEngine): self._connection = await self._engine.connect() else: @@ -748,16 +762,19 @@ async def __aexit__( _value: Optional[BaseException], _tb: Optional[TracebackType], ) -> None: + """Exit the ``with`` section, closing the connection.""" if isinstance(self._connection, AsyncConnection): await self._connection.close() self._connection.close() async def execute_raw_query(self, query: Executable) -> CursorResult: + """Execute the query on the owned connection.""" if isinstance(self._connection, AsyncConnection): return await self._connection.execute(query) return self._connection.execute(query) async def table_row_count(self, table_name: str) -> int: + """Count the number of rows in the named table.""" with await self.execute_raw_query( text(f"SELECT COUNT(*) FROM {table_name}") ) as result: @@ -790,12 +807,14 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any: def fix_type(value: Any) -> Any: + """Make this value suitable for yaml output.""" if type(value) is decimal.Decimal: return float(value) return value def fix_types(dics: list[dict]) -> list[dict]: + """Make all the items in this list suitable for yaml output.""" return [{k: fix_type(v) for k, v in dic.items()} for dic in dics] @@ -827,6 +846,7 @@ async def make_src_stats_connection( ) -> dict[str, dict[str, Any]]: """ Make the ``src-stats.yaml`` file given the database connection to read from. + :param config: configuration from ``config.yaml``. :param db_conn: Source database connection. :param metadata: Source database metadata from ``orm.yaml``. 
diff --git a/datafaker/providers.py b/datafaker/providers.py index 65abf069..75006c7e 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -38,7 +38,7 @@ def increment(self, db_connection: Connection, column: Column) -> int: """Return incrementing value for the column specified.""" name = f"{column.table.name}.{column.name}" result = self.accumulators.get(name, None) - if result == None: + if result is None: row = db_connection.execute(select(func.max(column))).first() result = 0 if row is None or row[0] is None else row[0] value = result + 1 diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 0b96e346..b7f5cf2d 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,6 +1,6 @@ """Convert between a Python dict describing a database schema and a SQLAlchemy MetaData.""" import typing -from typing import Callable +from functools import partial import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table @@ -27,9 +27,7 @@ def simple(type_: type) -> ParserType: def integer() -> ParserType: - """ - Get a parser for an integer, outputting that integer. - """ + """Get a parser for an integer, outputting that integer.""" return parsy.regex(r"-?[0-9]+").map(int) @@ -164,6 +162,7 @@ def type_parser() -> ParserType: def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: """ Produce a dict description of a column. + :param column: The SQLAlchemy column to translate. :param dialect: The SQL dialect in which to render the type name. """ @@ -192,10 +191,11 @@ def dict_to_column( table_name: str, col_name: str, rep: dict, - ignore_fk: Callable[[str], bool], + ignore_fk: typing.Callable[[str], bool], ) -> Column: """ Produce column from aspects of its dict description. + :param table_name: The name of the table the column appears in. :param col_name: The name of the column. :param rep: The dict description of the column. 
@@ -249,7 +249,7 @@ def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: def table_to_dict(table: Table, dialect: Dialect) -> TableT: - """Converts a SQL Alchemy Table object into a Python dict.""" + """Convert a SQL Alchemy Table object into a Python dict.""" return { "columns": { str(column.key): column_to_dict(column, dialect) @@ -267,7 +267,7 @@ def dict_to_table( name: str, meta: MetaData, table_dict: TableT, - ignore_fk: Callable[[str], bool], + ignore_fk: typing.Callable[[str], bool], ) -> Table: """Create a Table from its description.""" return Table( @@ -285,8 +285,9 @@ def metadata_to_dict( meta: MetaData, schema_name: str | None, engine: Engine ) -> dict[str, typing.Any]: """ - Converts a SQL Alchemy MetaData object into - a Python object ready for conversion to YAML. + Convert a metadata object into a Python dict. + + The output will be ready for output to ``orm.yaml``. """ return { "tables": { @@ -298,10 +299,13 @@ def metadata_to_dict( } -def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool: +def should_ignore_fk(tables_dict: dict[str, TableT], fk: str) -> bool: """ - Tell if this foreign key should be ignored because it points to an - ignored table. + Test if this foreign key points to an ignored table. + + If so, this foreign key should be ignored. + :param tables_dict: The ``tables`` value from ``config.yaml``. + :param fk: The name of the foreign key. """ fk_bits = fk.split(".", 2) if len(fk_bits) != 2: @@ -311,9 +315,13 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool: return bool(tables_dict[fk_bits[0]].get("ignore", False)) +def _always_false(_: str) -> bool: + return False + + def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaData: """ - Converts a dict to a SQL Alchemy MetaData object. + Convert a dict to a SQL Alchemy MetaData object. :param config_for_output: The configuration object. 
Should be None if the metadata object is being used for connecting to the source database. @@ -322,11 +330,12 @@ def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaDa constraint to an ignored table. """ tables_dict = obj.get("tables", {}) + ignore_fk: typing.Callable[[str], bool] if config_for_output and "tables" in config_for_output: tables_config = config_for_output["tables"] - ignore_fk = lambda fk: should_ignore_fk(fk, tables_config) + ignore_fk = partial(should_ignore_fk, tables_config) else: - ignore_fk = lambda _: False + ignore_fk = _always_false meta = MetaData() for k, td in tables_dict.items(): dict_to_table(k, meta, td, ignore_fk) diff --git a/datafaker/utils.py b/datafaker/utils.py index b34664c2..009109d2 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -50,7 +50,6 @@ ) T = TypeVar("T") -_K = TypeVar("_K") def read_config_file(path: str) -> dict: @@ -96,14 +95,29 @@ def import_file(file_path: str) -> ModuleType: def open_file(file_name: str | Path) -> io.BufferedWriter: + """Open a file for writing.""" return Path(file_name).open("wb") def open_compressed_file(file_name: str | Path) -> gzip.GzipFile: + """ + Open a gzip-compressed file for writing. + + :param file_name: The name of the file to open. + :return: A file object; it can be written to as a normal uncompressed + file and it will do the compression. + """ return gzip.GzipFile(file_name, "wb") def table_row_count(table: Table, conn: Connection) -> int: + """ + Count the rows in the table. + + :param table: The table to count. + :param conn: The connection to the database. + :return: The number of rows in the table. + """ return conn.execute( select(sqlalchemy.func.count()).select_from( sqlalchemy.table( @@ -222,10 +236,12 @@ def warning_or_higher(record: logging.LogRecord) -> bool: class StdoutHandler(logging.Handler): """ A handler that writes to stdout. 
+ We aren't using StreamHandler because that confuses typer.testing.CliRunner """ def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stdout.flush() @@ -233,6 +249,7 @@ def flush(self) -> None: self.release() def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stdout.write(msg + "\n") @@ -246,10 +263,12 @@ def emit(self, record: Any) -> None: class StderrHandler(logging.Handler): """ A handler that writes to stderr. + We aren't using StreamHandler because that confuses typer.testing.CliRunner """ def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stderr.flush() @@ -257,6 +276,7 @@ def flush(self) -> None: self.release() def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stderr.write(msg + "\n") @@ -293,19 +313,37 @@ def conf_logger(verbose: bool) -> None: logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) -def get_flag(maybe_dict: Any, key: Any) -> Any: - """Returns maybe_dict[key] or False if that doesn't exist""" - return type(maybe_dict) is dict and maybe_dict.get(key, False) +def get_flag(maybe_dict: Any, key: Any) -> bool: + """ + Get a boolean from a mapping, or False if that does not make sense. + + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :return: True only if ``maybe_dict`` is a mapping, ``maybe_dict[key]`` + exists and ``maybe_dict[key]`` is truthy. + """ + return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False) + +def get_property(maybe_dict: Any, key: Any, default: T) -> T: + """ + Get a specific property from a dict or a default if that does not exist. -def get_property(maybe_dict: Mapping[_K, Any], key: _K, default: T) -> T: - """Returns maybe_dict[key] or default if that doesn't exist""" - return maybe_dict.get(key, default) if type(maybe_dict) is dict else default + :param maybe_dict: A mapping, or possibly not. 
+ :param key: A key in ``maybe_dict``, or possibly not. + :param default: The return value if ``maybe_dict`` is not a mapping, + or if ``key`` is not a key of ``maybe_dict``. + :return: ``maybe_dict[key]`` if this makes sense, or ``default`` if not. + """ + return maybe_dict.get(key, default) if isinstance(maybe_dict, Mapping) else default def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: """ - Does this foreign key refer to a table that is configured as ignore in config.yaml + Test if this foreign key refers to an ignored table. + + :param fk: The foreign key to test. + :return: True if the table referred to is ignored in ``config.yaml``. """ try: fk.column @@ -316,7 +354,10 @@ def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ - Does this foreign key constraint refer to a table that is configured as ignore in config.yaml + Test if the constraint refers to a table marked as ignored in ``config.yaml``. + + :param fk: The foreign key constraint. + :return: True if ``fk`` refers to an ignored table. """ try: fk.referred_table @@ -328,6 +369,10 @@ def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: def get_related_table_names(table: Table) -> set[str]: """ Get the names of all tables for which there exist foreign keys from this table. + + :param table: SQLAlchemy table object. + :return: The set of the names of the tables referred to by foreign keys + in ``table``. """ return { str(fk.referred_table.name) @@ -338,8 +383,11 @@ def get_related_table_names(table: Table) -> set[str]: def table_is_private(config: Mapping, table_name: str) -> bool: """ - Return True if the table with name table_name is a primary private table - according to config. + Test if the named table is private. + + :param config: The ``config.yaml`` object. + :param table_name: The name of the table to test. + :return: True if the table is marked as private in ``config``. 
""" ts = config.get("tables", {}) if type(ts) is not dict: @@ -351,10 +399,14 @@ def table_is_private(config: Mapping, table_name: str) -> bool: def primary_private_fks(config: Mapping, table: Table) -> list[str]: """ - Returns the list of columns in the table that refer to primary private tables. + Get the list of columns in the table that refer to primary private tables. A table that is not primary private but has a non-empty list of primary_private_fks is secondary private. + + :param config: The ``config.yaml`` object. + :param table: The table to examine. + :return: A list of names of columns that refer to private tables. """ return [ str(fk.referred_table.name) @@ -365,9 +417,7 @@ def primary_private_fks(config: Mapping, table: Table) -> list[str]: def get_vocabulary_table_names(config: Mapping) -> set[str]: - """ - Extract the table names with a vocabulary_table: true property. - """ + """Extract the table names with a vocabulary_table: true property.""" return { table_name for (table_name, table_config) in config.get("tables", {}).items() @@ -376,6 +426,7 @@ def get_vocabulary_table_names(config: Mapping) -> set[str]: def make_foreign_key_name(table_name: str, col_name: str) -> str: + """Make a suitable foreign key name.""" return f"{table_name}_{col_name}_fkey" @@ -384,6 +435,16 @@ def remove_vocab_foreign_key_constraints( config: Mapping[str, Any], dst_engine: Connection | Engine, ) -> None: + """ + Remove the foreign key constraints from vocabulary tables. + + This allows vocabulary tables to be loaded without worrying about + topologically sorting them or circular dependencies. + + :param metadata: The SQLAlchemy metadata from ``orm.yaml``. + :param config: The ``config.yaml`` object. + :param dst_engine: The destination database or a connection to it. 
+ """ vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] @@ -419,6 +480,7 @@ def reinstate_vocab_foreign_key_constraints( ) -> None: """ Put the removed foreign keys back into the destination database. + :param metadata: The SQLAlchemy metadata for the destination database. :param meta_dict: The ``orm.yaml`` configuration that ``metadata`` was created from. @@ -525,6 +587,14 @@ def topological_sort( def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Table]: + """ + Get the list of non-vocabulary tables, topologically sorted. + + :param metadata: SQLAlchemy database description. + :param config: The ``config.yaml`` object. + :return: The list of non-vocabulary tables, ordered such that the targets + of all the foreign keys come before their sources. + """ table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) @@ -537,8 +607,9 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta def underline_error(e: SyntaxError) -> str: - """ + r""" Make an underline for this error. + :return: string beginning ``\n`` then spaces then ``^^^^`` underlining the error, or a null string if this was not possible. """ @@ -553,7 +624,11 @@ def underline_error(e: SyntaxError) -> str: def generators_require_stats(config: Mapping) -> bool: """ - Returns true if any of the arguments for any of the generators reference SRC_STATS. + Test if the generator references ``SRC_STATS``. + + :param config: ``config.yaml`` object. + :return: True if any of the arguments for any of the generators + reference ``SRC_STATS``. 
""" ois = { f"object_instantiation.{k}": call diff --git a/tests/test_dump.py b/tests/test_dump.py index 7033e18f..2340a6cc 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -18,11 +18,11 @@ class DumpTests(RequiresDBTestCase): @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: """Test dump-data.""" - TEST_OUTPUT_FILE = io.StringIO() + test_output_file = io.StringIO() metadata = MetaData() metadata.reflect(self.sync_engine) - dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) - make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) + dump_db_tables(metadata, self.dsn, self.schema_name, "player", test_output_file) + make_csv_writer.assert_called_once_with(test_output_file) make_csv_writer.assert_has_calls( [ call().writerow(["id", "given_name", "family_name"]), diff --git a/tests/test_functional.py b/tests/test_functional.py index e60baa1e..792dc40a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -349,33 +349,50 @@ def test_workflow_maximal_args(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - { - "Creating data.", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for table 'data_type_test'", - "Generating data for table 'no_pk_test'", - "Generating data for table 'person'", - "Generating data for table 'strange_type_table'", - "Generating data for table 'unique_constraint_test'", - "Generating data for table 'unique_constraint_test2'", - "Generating data for table 'test_entity'", - "Generating data for table 'hospital_visit'", - "Data created in 2 passes.", - f"person: 
{2*(3+1+2+2)} rows created.", - f"hospital_visit: {2*(2*2+3)} rows created.", - "data_type_test: 2 rows created.", - "no_pk_test: 2 rows created.", - "strange_type_table: 2 rows created.", - "unique_constraint_test: 2 rows created.", - "unique_constraint_test2: 2 rows created.", - "test_entity: 2 rows created.", - }, - set(completed_process.stdout.split("\n")) - {""}, + sorted( + [ + "Creating data.", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for table 'data_type_test'", + "Generating data for table 'data_type_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'person'", + "Generating data for table 'person'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'test_entity'", + "Generating data for table 'test_entity'", + "Generating data for table 'hospital_visit'", + "Generating data for table 'hospital_visit'", + "Data created in 2 passes.", + f"person: {2*(3+1+2+2)} rows 
created.", + f"hospital_visit: {2*(2*2+3)} rows created.", + "data_type_test: 2 rows created.", + "no_pk_test: 2 rows created.", + "strange_type_table: 2 rows created.", + "unique_constraint_test: 2 rows created.", + "unique_constraint_test2: 2 rows created.", + "test_entity: 2 rows created.", + "", + ] + ), + sorted(completed_process.stdout.split("\n")), ) completed_process = self.invoke( @@ -452,8 +469,20 @@ def test_workflow_maximal_args(self) -> None: ) def invoke( - self, *args: Any, expected_error: str | None = None, env: Mapping[str, str] = {} + self, + *args: Any, + expected_error: str | None = None, + env: Mapping[str, str] | None = None, ) -> Result: + """ + Run datafaker with the given arguments and environment. + + :param args: Arguments to provide to datafaker. + :param expected_error: If None, will assert that the invocation + passes successfully without throwing an exception. Otherwise, + the suggested error must be present in the standard error stream. + :param env: The environment variables to be set during invocation. 
+ """ res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index a7803f01..c5e79d34 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -3,7 +3,7 @@ import random import re from dataclasses import dataclass -from typing import Any, Iterable, Mapping, MutableMapping +from typing import Any, Iterable, MutableMapping from unittest.mock import MagicMock, Mock, patch from sqlalchemy import insert, select @@ -46,7 +46,7 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str] | None, displaywidth: int = 80) -> None: + def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: """Capture the printed table.""" if items is not None: self.column_items.append(items) @@ -587,7 +587,10 @@ def test_set_generator_distribution(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), ) def test_set_generator_distribution_directly(self) -> None: @@ -608,7 +611,10 @@ def test_set_generator_distribution_directly(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), ) def test_set_generator_choice(self) -> None: @@ -642,7 +648,11 @@ def test_set_generator_choice(self) -> None: ) self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT {column} 
AS value FROM {table} WHERE {column} IS NOT NULL GROUP BY value ORDER BY COUNT({column}) DESC", + ( + f"SELECT {column} AS value FROM {table}" + f" WHERE {column} IS NOT NULL" + f" GROUP BY value ORDER BY COUNT({column}) DESC" + ), ) def test_weighted_choice_generator_generates_choices(self) -> None: @@ -761,7 +771,10 @@ def test_old_generators_remain(self) -> None: "src-stats": [ { "name": "auto__string", - "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), } ], } @@ -798,7 +811,10 @@ def test_old_generators_remain(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], - "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), ) def test_aggregate_queries_merge(self) -> None: @@ -824,7 +840,10 @@ def test_aggregate_queries_merge(self) -> None: "src-stats": [ { "name": "auto__string", - "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), } ], } @@ -1712,7 +1731,7 @@ class NonInteractiveTests(RequiresDBTestCase): ), ) def test_non_interactive_configure_generators( - self, mock_csv_reader: MagicMock, mock_path: MagicMock + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock ) -> None: """ test that we can set generators from a CSV file diff --git a/tests/test_main.py b/tests/test_main.py index a37eaf4f..a570fe4d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -194,7 +194,7 @@ def test_create_generators_with_force_enabled( @patch("datafaker.main.read_config_file") 
@patch("datafaker.main.load_metadata_for_output") def test_create_tables( - self, mock_load_meta: MagicMock, mock_config: MagicMock, mock_create: MagicMock + self, mock_load_meta: MagicMock, _mock_config: MagicMock, mock_create: MagicMock ) -> None: """Test the create-tables sub-command.""" diff --git a/tests/test_remove.py b/tests/test_remove.py index 24286fba..a6dbb85c 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -18,12 +18,14 @@ class RemoveThingsTestCase(RequiresDBTestCase): schema_name = "public" def count_rows(self, connection: Connection, table_name: str) -> int | None: + """Count the rows in a table.""" return connection.execute( select(func.count()).select_from(self.metadata.tables[table_name]) ).scalar() @patch("datafaker.remove.get_settings") def test_remove_data(self, mock_get_settings: MagicMock) -> None: + """Test that data can be removed from non-vocabulary tables.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, @@ -67,6 +69,7 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: @patch("datafaker.remove.get_settings") def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: + """Test that vocabulary tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, @@ -114,6 +117,7 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: @patch("datafaker.remove.get_settings") def test_remove_tables(self, mock_get_settings: MagicMock) -> None: + """Test that destination tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, diff --git a/tests/test_rst.py b/tests/test_rst.py index 090658c1..78f07472 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -7,6 +7,10 @@ from restructuredtext_lint import lint_file +def _level_to_string(level: int) -> str: + return ["Severe", "Error", "Warning"][level] + + class RstTests(TestCase): """Linting for the doc .rst 
files.""" @@ -44,7 +48,11 @@ def test_dir(self) -> None: ] if filtered_errors: - self.fail(msg="\n".join([ - f"{err.source}({err.line}): {["Severe", "Error", "Warning"][err.level]}: {err.full_message}" - for err in filtered_errors - ])) + self.fail( + msg="\n".join( + [ + f"{err.source}({err.line}): {_level_to_string(err.level)}: {err.full_message}" + for err in filtered_errors + ] + ) + ) From b86e10604ec85b08f073867d990096867c9f5fe2 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Oct 2025 19:17:30 +0100 Subject: [PATCH 18/44] More cleaning --- datafaker/generators.py | 6 +++--- datafaker/main.py | 7 +++++-- datafaker/utils.py | 12 ++++++------ tests/test_interactive.py | 6 +++--- tests/test_rst.py | 20 ++++++++++++-------- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 4e421af1..a925b79f 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -21,7 +21,7 @@ from datafaker.base import DistributionGenerator from datafaker.utils import T, logger -numeric = Union[int, float] +NumericT = Union[int, float] # How many distinct values can we have before we consider a # choice distribution to be infeasible? 
@@ -754,7 +754,7 @@ def get_generators( return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: +def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) @@ -764,7 +764,7 @@ def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: class ContinuousDistributionGenerator(Generator): """Base class for generators producing continuous distributions.""" - expected_buckets: Sequence[numeric] = [] + expected_buckets: Sequence[NumericT] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): """Initialise a ContinuousDistributionGenerator.""" diff --git a/datafaker/main.py b/datafaker/main.py index 42f6bdea..c1299830 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -437,7 +437,10 @@ def configure_generators( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), spec: Path = Option( None, - help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", + help=( + "CSV file (headerless) with fields table-name," + " column-name, generator-name to set non-interactively" + ), ), ) -> None: """Interactively set generators for column data.""" @@ -482,7 +485,7 @@ def dump_data( if isinstance(sys.stdout, io.TextIOBase): dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return - with open(output, "wt", newline="") as out: + with open(output, "wt", newline="", encoding="utf-8") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) diff --git a/datafaker/utils.py b/datafaker/utils.py index 009109d2..61089ab3 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -390,11 +390,11 @@ def table_is_private(config: Mapping, table_name: str) -> bool: :return: True if the table is marked as private in ``config``. 
""" ts = config.get("tables", {}) - if type(ts) is not dict: + if not isinstance(ts, Mapping): return False t = ts.get(table_name, {}) ret = t.get("primary_private", False) - return ret if type(ret) is bool else False + return ret if isinstance(ret, bool) else False def primary_private_fks(config: Mapping, table: Table) -> list[str]: @@ -466,7 +466,7 @@ def remove_vocab_foreign_key_constraints( ) except ProgrammingError as e: session.rollback() - if type(e.orig) is UndefinedObject: + if isinstance(e.orig, UndefinedObject): logger.debug("Constraint does not exist") else: raise e @@ -501,7 +501,7 @@ def reinstate_vocab_foreign_key_constraints( name=make_foreign_key_name(vocab_table_name, column_name), refcolumns=fk_targets, ) - logger.debug(f"Restoring foreign key constraint {fk.name}") + logger.debug("Restoring foreign key constraint %s", fk.name) with Session(dst_engine) as session: session.begin() vocab_table.append_constraint(fk) @@ -598,12 +598,12 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) - (sorted, cycles) = topological_sort( + (sorted_tables, cycles) = topological_sort( table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: logger.warning(f"Cycle detected between tables: {cycle}") - return [metadata.tables[tn] for tn in sorted] + return [metadata.tables[tn] for tn in sorted_tables] def underline_error(e: SyntaxError) -> str: diff --git a/tests/test_interactive.py b/tests/test_interactive.py index c5e79d34..1082c337 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -46,10 +46,10 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: + def columnize(self, list: list[str] | None, _displaywidth: int = 80) -> None: 
"""Capture the printed table.""" - if items is not None: - self.column_items.append(items) + if list is not None: + self.column_items.append(list) def ask_save(self) -> str: """Quitting always works without needing to ask the user.""" diff --git a/tests/test_rst.py b/tests/test_rst.py index 78f07472..29bf971e 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -2,15 +2,26 @@ The CLI does not allow errors to be disabled, but we can ignore them here.""" from pathlib import Path +from typing import Any from unittest import TestCase from restructuredtext_lint import lint_file def _level_to_string(level: int) -> str: + """Get a string description of an error level.""" return ["Severe", "Error", "Warning"][level] +def _error_message(lint_error: Any) -> str: + """Turn a linting error into an error message.""" + source = getattr(lint_error, "source") + line = getattr(lint_error, "line") + level = _level_to_string(getattr(lint_error, "level")) + message = getattr(lint_error, "full_message") + return f"{source}({line}): {level}: {message}" + + class RstTests(TestCase): """Linting for the doc .rst files.""" @@ -48,11 +59,4 @@ def test_dir(self) -> None: ] if filtered_errors: - self.fail( - msg="\n".join( - [ - f"{err.source}({err.line}): {_level_to_string(err.level)}: {err.full_message}" - for err in filtered_errors - ] - ) - ) + self.fail(msg="\n".join(map(_error_message, filtered_errors))) From e1dec20231818a1af7e706f0c9fd187bd292612f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 10 Oct 2025 18:44:37 +0100 Subject: [PATCH 19/44] Lots of pylint cleaning --- datafaker/base.py | 27 ++- datafaker/create.py | 9 +- datafaker/generators.py | 216 +++++++++++++++------- datafaker/interactive.py | 226 ++++++++++++----------- datafaker/main.py | 4 +- datafaker/make.py | 40 ++--- datafaker/utils.py | 24 +-- tests/test_functional.py | 36 ++-- tests/test_interactive.py | 318 +++++++++++++++------------------ tests/test_main.py | 21 +-- tests/test_make.py | 8 +- 
tests/test_providers.py | 3 +- tests/test_remove.py | 5 +- tests/test_rst.py | 2 +- tests/test_unique_generator.py | 1 - tests/utils.py | 39 ++-- 16 files changed, 522 insertions(+), 457 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 4ceb2aff..6ff1890e 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -5,10 +5,11 @@ import os import random from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass +from io import TextIOWrapper from pathlib import Path -from typing import Any, Callable, Generator +from typing import Any, Generator import numpy as np import yaml @@ -59,8 +60,7 @@ def merge_with_constants( yield xs[xi] xi += 1 outi += 1 - for x in xs[xi:]: - yield x + yield from xs[xi:] class NothingToGenerateException(Exception): @@ -132,7 +132,7 @@ def choice(self, a: list[T]) -> T: :return: The chosen value. """ c = random.choice(a) - return c["value"] if type(c) is dict and "value" in c else c + return c["value"] if isinstance(c, Mapping) and "value" in c else c def zipf_choice(self, a: list[T], n: int | None = None) -> T: """ @@ -149,7 +149,7 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] - return c["value"] if type(c) is dict and "value" in c else c + return c["value"] if isinstance(c, Mapping) and "value" in c else c def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ @@ -403,17 +403,26 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None: """Load the data from file.""" yaml_file = base_path / Path(self.table.fullname + ".yaml") if yaml_file.exists(): - opener = lambda: open(yaml_file, mode="r", encoding="utf-8") + + def opener() -> TextIOWrapper: + return open(yaml_file, mode="r", encoding="utf-8") + else: yaml_file = base_path / Path(self.table.fullname + ".yaml.gz") if yaml_file.exists(): - opener = 
lambda: gzip.open(yaml_file, mode="rt") + + def opener() -> TextIOWrapper: + return gzip.open(yaml_file, mode="rt") + else: logger.warning("File %s not found. Skipping...", yaml_file) return if 0 < table_row_count(self.table, connection): logger.warning( - "Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", + ( + "Table %s already contains data" + " (consider running 'datafaker remove-vocab'), skipping..." + ), self.table.name, ) return diff --git a/datafaker/create.py b/datafaker/create.py index f11a0dd3..ce2a74e0 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -235,11 +235,10 @@ def next(self) -> None: if self._final_values is None: self._table_name, self._provided_values = next(self._story) return - else: - self._table_name, self._provided_values = self._story.send( - self._final_values - ) - return + self._table_name, self._provided_values = self._story.send( + self._final_values + ) + return except StopIteration: try: name, self._story = next(self._stories) diff --git a/datafaker/generators.py b/datafaker/generators.py index a925b79f..0b760e45 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -21,7 +21,7 @@ from datafaker.base import DistributionGenerator from datafaker.utils import T, logger -NumericT = Union[int, float] +NumericType = Union[int, float] # How many distinct values can we have before we consider a # choice distribution to be infeasible? @@ -102,7 +102,8 @@ def custom_queries(self) -> dict[str, dict[str, str]]: "query": "SELECT one, too AS two FROM mytable WHERE too > 1", "comment": "big enough one and two from table mytable" }} - will populate SRC_STATS["myquery"]["results"][0]["one"] and SRC_STATS["myquery"]["results"][0]["two"] + will populate SRC_STATS["myquery"]["results"][0]["one"] + and SRC_STATS["myquery"]["results"][0]["two"] in the src-stats.yaml file. 
Keys should be chosen to minimize the chances of clashing with other queries, @@ -142,18 +143,17 @@ class PredefinedGenerator(Generator): def _get_src_stats_mentioned(self, val: Any) -> set[str]: if not val: return set() - if type(val) is str: + if isinstance(val, str): ss = self.SRC_STAT_NAME_RE.match(val) if ss: ss_name = ss.group(1) logger.debug("Found SRC_STATS reference %s", ss_name) return set([ss_name]) - else: - logger.debug("Value %s does not seem to be a SRC_STATS reference", val) - return set() - if type(val) is list: + logger.debug("Value %s does not seem to be a SRC_STATS reference", val) + return set() + if isinstance(val, list): return set.union(*(self._get_src_stats_mentioned(v) for v in val)) - if type(val) is dict: + if isinstance(val, dict): return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) return set() @@ -278,12 +278,9 @@ def __init__( with engine.connect() as connection: raw_buckets = connection.execute( text( - "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( - column=column_name, - table=table_name, - x=mean - 2 * stddev, - w=stddev / 2, - ) + f"SELECT COUNT({column_name}) AS f," + f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b" + f" FROM {table_name} GROUP BY b" ) ) self.buckets: Sequence[int] = [0] * 10 @@ -310,10 +307,9 @@ def make_buckets( with engine.connect() as connection: result = connection.execute( text( - "SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( - table=table_name, - column=column_name, - ) + f"SELECT AVG({column_name}) AS mean," + f" STDDEV({column_name}) AS stddev," + f" COUNT({column_name}) AS count FROM {table_name}" ) ).first() if result is None or result.stddev is None or getattr(result, "count") < 2: @@ -388,7 +384,8 @@ def __init__( f = getattr(f, part) if not callable(f): raise Exception( - f"Mimesis object {function_name} is not a callable, so cannot be used as a 
generator" + f"Mimesis object {function_name} is not a callable," + " so cannot be used as a generator" ) self._name = "generic." + function_name self._generator_function = f @@ -520,7 +517,7 @@ def __init__( @classmethod def make_singleton( - _cls, column: Column, engine: Engine, function_name: str + cls, column: Column, engine: Engine, function_name: str ) -> Sequence[Generator]: """Make the appropriate generation configuration for this column.""" extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" @@ -548,8 +545,14 @@ def make_singleton( def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', - "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', + "start": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__start"]' + ), + "end": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__end"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -564,11 +567,17 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { f"{self._column.name}__start": { "clause": self._min_year, - "comment": f"Earliest year found for column {self._column.name} in table {self._column.table.name}", + "comment": ( + f"Earliest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), }, f"{self._column.name}__end": { "clause": self._max_year, - "comment": f"Latest year found for column {self._column.name} in table {self._column.table.name}", + "comment": ( + f"Latest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), }, } @@ -642,7 +651,7 @@ def get_generators( f"LENGTH({column.name})", ) fitness_fn = len - except Exception as exc: + except Exception: # Some column types that appear to be strings (such as 
enums) # cannot have their lengths measured. In this case we cannot # detect fitness using lengths. @@ -754,7 +763,7 @@ def get_generators( return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: +def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) @@ -764,7 +773,7 @@ def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: class ContinuousDistributionGenerator(Generator): """Base class for generators producing continuous distributions.""" - expected_buckets: Sequence[NumericT] = [] + expected_buckets: Sequence[NumericType] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): """Initialise a ContinuousDistributionGenerator.""" @@ -776,8 +785,14 @@ def __init__(self, table_name: str, column_name: str, buckets: Buckets): def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', - "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', + "mean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["mean__{self.column_name}"]' + ), + "sd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["stddev__{self.column_name}"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -946,8 +961,14 @@ def generate_data(self, count: int) -> list[Any]: def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', - "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', + "logmean": ( + 
f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logmean__{self.column_name}"]' + ), + "logsd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logstddev__{self.column_name}"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -963,12 +984,21 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { **clauses, f"logmean__{self.column_name}": { - "clause": f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", + "clause": ( + f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" + " ELSE NULL END)" + ), "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", }, f"logstddev__{self.column_name}": { - "clause": f"STDDEV(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", - "comment": f"Standard deviation of logs of {self.column_name} from table {self.table_name}", + "clause": ( + f"STDDEV(CASE WHEN 0<{self.column_name}" + f" THEN LN({self.column_name}) ELSE NULL END)" + ), + "comment": ( + f"Standard deviation of logs of {self.column_name}" + f" from table {self.table_name}" + ), }, } @@ -992,10 +1022,10 @@ def _get_generators_from_buckets( with engine.connect() as connection: result = connection.execute( text( - "SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( - table=table_name, - column=column_name, - ) + f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" + " ELSE NULL END) AS logmean," + f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" + f" AS logstddev FROM {table_name}" ) ).first() if result is None or result.logstddev is None: @@ -1064,21 +1094,56 @@ def __init__( extra_comment = " and their counts" if suppress_count == 0: if sample_count is None: - self._query = f"SELECT {column_name} AS value{extra_results} FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY 
value ORDER BY COUNT({column_name}) DESC" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name}" + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL GROUP BY value" + f" ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name}" + ) self._annotation = None else: - self._query = f"SELECT {column_name} AS value{extra_results} FROM (SELECT {column_name} FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" - self._comment = f"The values{extra_comment} that appear in column {column_name} of a random sample of {sample_count} rows of table {table_name}" + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM" + f" (SELECT {column_name} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL" + f" ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"The values{extra_comment} that appear in column {column_name}" + f" of a random sample of {sample_count} rows of table {table_name}" + ) self._annotation = "sampled" else: if sample_count is None: - self._query = f"SELECT value{extra_expo} FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name} more than {suppress_count} times" + self._query = ( + f"SELECT value{extra_expo} FROM" + f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" + f" FROM {table_name} WHERE {column_name} IS NOT NULL" + f" GROUP BY value ORDER BY count DESC) AS _inner" + f" WHERE 
{suppress_count} < count" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name} more than {suppress_count} times" + ) self._annotation = "suppressed" else: - self._query = f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"The values{extra_comment} that appear more than {suppress_count} times in column {column_name}, out of a random sample of {sample_count} rows of table {table_name}" + self._query = ( + f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" + f" (SELECT {column_name} AS value FROM {table_name}" + f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY count DESC)" + f" AS _inner WHERE {suppress_count} < count" + ) + self._comment = ( + f"The values{extra_comment} that appear more than {suppress_count} times" + f" in column {column_name}, out of a random sample of {sample_count} rows" + f" of table {table_name}" + ) self._annotation = "sampled and suppressed" @abstractmethod @@ -1220,14 +1285,14 @@ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: if c != 0: counts.append(c) v = result.v - if type(v) is decimal.Decimal: + if isinstance(v, decimal.Decimal): v = float(v) values.append(v) cvs.append({"value": v, "count": c}) if suppress_count < c: counts_not_suppressed.append(c) v = result.v - if type(v) is decimal.Decimal: + if isinstance(v, decimal.Decimal): v = float(v) values_not_suppressed.append(v) cvs_not_suppressed.append({"value": v, "count": c}) @@ -1258,11 +1323,9 @@ def get_generators( with engine.connect() as connection: results = connection.execute( text( - "SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP 
BY v ORDER BY f DESC LIMIT {limit}".format( - table=table_name, - column=column_name, - limit=MAXIMUM_CHOICES + 1, - ) + f"SELECT {column_name} AS v, COUNT({column_name})" + f" AS f FROM {table_name} GROUP BY v" + f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: @@ -1281,11 +1344,10 @@ def get_generators( ] results = connection.execute( text( - "SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( - table=table_name, - column=column_name, - sample_count=self.SAMPLE_COUNT, - ) + f"SELECT v, COUNT(v) AS f FROM" + f" (SELECT {column_name} as v FROM {table_name}" + f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" + f" AS _inner GROUP BY v ORDER BY f DESC" ) ) if results is not None: @@ -1436,7 +1498,10 @@ def custom_queries(self) -> dict[str, Any]: cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { - "comment": f"Means and covariate matrix for the columns {cols}, so that we can produce the relatedness between these in the fake data.", + "comment": ( + f"Means and covariate matrix for the columns {cols}," + " so that we can produce the relatedness between these in the fake data." 
+ ), "query": self._query, } } @@ -1511,14 +1576,20 @@ def query( ) means = "".join(f", _q.m{i}" for i in range(len(columns))) covs = "".join( - f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" + ( + f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" + f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" + ) for iy in range(len(columns)) for ix in range(iy + 1) ) if sample_count is None: subquery = table + where else: - subquery = f"(SELECT * FROM {table}{where} ORDER BY RANDOM() LIMIT {sample_count}) AS _sampled" + subquery = ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) # if there are any numeric columns we need at least two rows to make any (co)variances at all suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" return ( @@ -1689,7 +1760,10 @@ def _nominal_kwargs_with_combinations( self, index: int, partition: RowPartition ) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml`` for a single partition.""" - count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' + count = ( + 'sum(r["count"] for r in' + f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' + ) if not partition.included_numeric and not partition.included_choice: return { "count": count, @@ -1799,12 +1873,12 @@ def fit(self, default: float = -1) -> float: def is_numeric(col: Column) -> bool: """Test if this column stores a numeric value.""" ct = get_column_type(col) - return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys + return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys -def powerset(input: list[T]) -> Iterable[Iterable[T]]: +def powerset(xs: list[T]) -> Iterable[Iterable[T]]: """Get a list of all sublists of ``input``.""" - return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) + return chain.from_iterable(combinations(xs, n) 
for n in range(len(xs) + 1)) @dataclass @@ -1911,7 +1985,11 @@ def get_partition_count_query( ) if where is None: return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' - return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' + return ( + 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' + f' {index_exp} AS "index"' + f' FROM {table} GROUP BY "index") AS _q {where}' + ) def get_generators( self, columns: list[Column], engine: Engine @@ -1973,7 +2051,11 @@ def get_generators( partition_count_max_results = ( connection.execute(text(partition_query_max)).mappings().fetchall() ) - count_comment = f"Number of rows for each combination of the columns { {nc.column.name for nc in nullable_columns} } of the table {table} being null" + count_comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table} being null" + ) if self._execute_partition_queries(connection, row_partitions_maximal): gens.append( NullPartitionedNormalGenerator( diff --git a/datafaker/interactive.py b/datafaker/interactive.py index d35806d1..17efa199 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -4,7 +4,7 @@ import functools import re from abc import ABC, abstractmethod -from collections.abc import Mapping, MutableMapping +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from enum import Enum from pathlib import Path @@ -13,7 +13,7 @@ import sqlalchemy from prettytable import PrettyTable -from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text +from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table from typing_extensions import Self from datafaker.generators import Generator, PredefinedGenerator, everything_factory @@ -123,7 +123,9 @@ class DbCmd(ABC, cmd.Cmd): ROW_COUNT_MSG = "Total row count: {}" 
@abstractmethod - def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None: + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> TableEntry | None: """ Make a table entry suitable for this interactive command. @@ -131,7 +133,6 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non :param table_config: The part of the ``config.yaml`` referring to this table. :return: The table entry or None if this table should not be interacted with. """ - ... def __init__( self, @@ -145,12 +146,12 @@ def __init__( self.config: MutableMapping[str, Any] = config self.metadata = metadata self._table_entries: list[TableEntry] = [] - tables_config: Mapping = config.get("tables", {}) - if type(tables_config) is not dict: + tables_config: MutableMapping = config.get("tables", {}) + if not isinstance(tables_config, MutableMapping): tables_config = {} for name in metadata.tables.keys(): table_config = tables_config.get(name, {}) - if type(table_config) is not dict: + if not isinstance(table_config, MutableMapping): table_config = {} entry = self.make_table_entry(name, table_config) if entry is not None: @@ -224,7 +225,7 @@ def set_prompt(self) -> None: """Set the prompt according to the current state.""" ... - def set_table_index(self, index: int) -> bool: + def _set_table_index(self, index: int) -> bool: """ Move to a different table. @@ -244,7 +245,7 @@ def next_table(self, report: str = "No more tables") -> bool: :param report: The text to print if there is no next table. :return: True if there is another table to move to. 
""" - if not self.set_table_index(self.table_index + 1): + if not self._set_table_index(self.table_index + 1): self.print(report) return False return True @@ -257,7 +258,7 @@ def table_metadata(self) -> Table: """Get the metadata of the current table.""" return self.metadata.tables[self.table_name()] - def get_column_names(self) -> list[str]: + def _get_column_names(self) -> list[str]: """Get the names of the current columns.""" return [col.name for col in self.table_metadata().columns] @@ -277,23 +278,25 @@ def report_columns(self) -> None: ], ) - def get_table_config(self, table_name: str) -> dict[str, Any]: + def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: """Get the configuration of the named table.""" ts = self.config.get("tables", None) - if type(ts) is not dict: + if not isinstance(ts, MutableMapping): return {} t = ts.get(table_name) - return t if type(t) is dict else {} + return t if isinstance(t, MutableMapping) else {} - def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: + def set_table_config( + self, table_name: str, config: MutableMapping[str, Any] + ) -> None: """Set the configuration of the named table.""" ts = self.config.get("tables", None) - if type(ts) is not dict: + if not isinstance(ts, MutableMapping): self.config["tables"] = {table_name: config} return ts[table_name] = config - def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: + def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: """Remove all source stats with the given prefix from the configuration.""" src_stats = self.config.get("src-stats", []) new_src_stats = [] @@ -323,7 +326,7 @@ def find_entry_index_by_table_name(self, table_name: str) -> int | None: None, ) - def find_entry_by_table_name(self, table_name: str) -> TableEntry | None: + def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None: """Get the table entry of the named table.""" for e in 
self._table_entries: if e.name == table_name: @@ -339,11 +342,8 @@ def do_counts(self, _arg: str) -> None: colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] with self.sync_engine.connect() as connection: result = connection.execute( - text( - "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( - table=table_name, - colcounts="".join(colcounts), - ) + sqlalchemy.text( + f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" ) ).first() if result is None: @@ -362,51 +362,52 @@ def do_counts(self, _arg: str) -> None: def do_select(self, arg: str) -> None: """Run a select query over the database and show the first 50 results.""" - MAX_SELECT_ROWS = 50 + max_select_rows = 50 with self.sync_engine.connect() as connection: try: - result = connection.execute(text("SELECT " + arg)) + result = connection.execute(sqlalchemy.text("SELECT " + arg)) except sqlalchemy.exc.DatabaseError as exc: self.print("Failed to execute: {}", exc) return row_count = result.rowcount self.print(self.ROW_COUNT_MSG, row_count) if 50 < row_count: - self.print("Showing the first {} rows", MAX_SELECT_ROWS) + self.print("Showing the first {} rows", max_select_rows) fields = list(result.keys()) - rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)] + rows = [row._tuple() for row in result.fetchmany(max_select_rows)] self.print_table(fields, rows) def do_peek(self, arg: str) -> None: """ View some data from the current table. - Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table. + Use 'peek col1 col2 col3' to see a sample of values from + columns col1, col2 and col3 in the current table. Use 'peek' to see a sample of the current column(s). Rows that are enitrely null are suppressed. 
""" - MAX_PEEK_ROWS = 25 + max_peek_rows = 25 if len(self._table_entries) <= self.table_index: return table_name = self.table_name() col_names = arg.split() if not col_names: - col_names = self.get_column_names() + col_names = self._get_column_names() nonnulls = [cn + " IS NOT NULL" for cn in col_names] with self.sync_engine.connect() as connection: - query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format( - cols=",".join(col_names), - table=table_name, - where="WHERE" if nonnulls else "", - nonnull=" OR ".join(nonnulls), - max=MAX_PEEK_ROWS, + cols = (",".join(col_names),) + where = ("WHERE" if nonnulls else "",) + nonnull = (" OR ".join(nonnulls),) + query = sqlalchemy.text( + f"SELECT {cols} FROM {table_name} {where} {nonnull}" + f" ORDER BY RANDOM() LIMIT {max_peek_rows}" ) try: - result = connection.execute(text(query)) + result = connection.execute(query) except Exception as exc: self.print(f'SQL query "{query}" caused exception {exc}') return - rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] + rows = [row._tuple() for row in result.fetchmany(max_peek_rows)] self.print_table(list(result.keys()), rows) def complete_peek( @@ -431,7 +432,10 @@ class TableCmdTableEntry(TableEntry): class TableCmd(DbCmd): """Command shell allowing the user to set the type of each table.""" - intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n" + intro = ( + "Interactive table configuration (ignore," + " vocabulary, private, generate or empty). Type ? for help.\n" + ) doc_leader = """Use the commands 'ignore', 'vocabulary', 'private', 'empty' or 'generate' to set the table's type. Use 'next' or 'previous' to change table. Use 'tables' and 'columns' for @@ -453,7 +457,9 @@ class TableCmd(DbCmd): NOTE_TEXT_NO_CHANGES = "You have made no changes." 
NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" - def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None: + def make_table_entry( + self, table_name: str, table: Mapping + ) -> TableCmdTableEntry | None: """ Make a table entry for the named table. @@ -462,14 +468,16 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No :return: The newly-constructed table entry. """ if table.get("ignore", False): - return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE) + return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE) if table.get("vocabulary_table", False): - return TableCmdTableEntry(name, TableType.VOCABULARY, TableType.VOCABULARY) + return TableCmdTableEntry( + table_name, TableType.VOCABULARY, TableType.VOCABULARY + ) if table.get("primary_private", False): - return TableCmdTableEntry(name, TableType.PRIVATE, TableType.PRIVATE) + return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE) if table.get("num_rows_per_pass", 1) == 0: - return TableCmdTableEntry(name, TableType.EMPTY, TableType.EMPTY) - return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) + return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY) + return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE) def __init__( self, @@ -487,9 +495,9 @@ def table_entries(self) -> list[TableCmdTableEntry]: """Get the list of table entries.""" return cast(list[TableCmdTableEntry], self._table_entries) - def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: """Get the table entry of the table with the given name.""" - entry = super().find_entry_by_table_name(table_name) + entry = super()._find_entry_by_table_name(table_name) if entry is None: return None return cast(TableCmdTableEntry, entry) @@ -556,7 +564,7 @@ def _sanity_check_failures(self) -> 
list[tuple[str, str, str]]: if from_t == TableType.VOCABULARY: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: - to_entry = self.find_entry_by_table_name(ref) + to_entry = self._find_entry_by_table_name(ref) if ( to_entry is not None and to_entry.new_type != TableType.VOCABULARY @@ -578,7 +586,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: if from_t in {TableType.GENERATE, TableType.PRIVATE}: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: - to_entry = self.find_entry_by_table_name(ref) + to_entry = self._find_entry_by_table_name(ref) if to_entry is not None and to_entry.new_type in { TableType.EMPTY, TableType.IGNORE, @@ -640,7 +648,7 @@ def do_next(self, arg: str) -> None: if index is None: self.print(self.ERROR_NO_SUCH_TABLE, arg) return - self.set_table_index(index) + self._set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) @@ -654,7 +662,7 @@ def complete_next( def do_previous(self, _arg: str) -> None: """Go to the previous table.""" - if not self.set_table_index(self.table_index - 1): + if not self._set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def do_ignore(self, _arg: str) -> None: @@ -676,7 +684,7 @@ def do_private(self, _arg: str) -> None: self.next_table() def do_generate(self, _arg: str) -> None: - """Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table.""" + """Set the current table as to be generated, and go to the next table.""" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() @@ -758,7 +766,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: ) with self.sync_engine.connect() as connection: result = connection.execute( - text( + sqlalchemy.text( "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( table=self.table_name(), column=column, @@ 
-777,7 +785,7 @@ def print_row_data(self, count: int) -> None: """ with self.sync_engine.connect() as connection: result = connection.execute( - text( + sqlalchemy.text( "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( table=self.table_name(), count=count, @@ -879,7 +887,7 @@ def find_missingness_query( return None def make_table_entry( - self, name: str, table: Mapping + self, table_name: str, table_config: Mapping ) -> MissingnessCmdTableEntry | None: """ Make a table entry for a particular table. @@ -888,15 +896,15 @@ def make_table_entry( :param table: The part of ``config.yaml`` relating to this table. :return: The newly-constructed table entry. """ - if table.get("ignore", False): + if table_config.get("ignore", False): return None - if table.get("vocabulary_table", False): + if table_config.get("vocabulary_table", False): return None - if table.get("num_rows_per_pass", 1) == 0: + if table_config.get("num_rows_per_pass", 1) == 0: return None - mgs = table.get("missingness_generators", []) + mgs = table_config.get("missingness_generators", []) old = None - nonnull_columns = self.get_nonnull_columns(name) + nonnull_columns = self.get_nonnull_columns(table_name) if not nonnull_columns: return None if not mgs: @@ -909,7 +917,7 @@ def make_table_entry( elif len(mgs) == 1: mg = mgs[0] mg_name = mg.get("name", None) - if type(mg_name) is str: + if isinstance(mg_name, str): query_comment = self.find_missingness_query(mg) if query_comment is not None: (query, comment) = query_comment @@ -922,7 +930,7 @@ def make_table_entry( if old is None: return None return MissingnessCmdTableEntry( - name=name, + name=table_name, old_type=old, new_type=old, ) @@ -950,11 +958,11 @@ def table_entries(self) -> list[MissingnessCmdTableEntry]: """Get the table entries list.""" return cast(list[MissingnessCmdTableEntry], self._table_entries) - def find_entry_by_table_name( + def _find_entry_by_table_name( self, table_name: str ) -> MissingnessCmdTableEntry | None: """Find 
the table entry given the table name.""" - entry = super().find_entry_by_table_name(table_name) + entry = super()._find_entry_by_table_name(table_name) if entry is None: return None return cast(MissingnessCmdTableEntry, entry) @@ -965,9 +973,9 @@ def set_prompt(self) -> None: entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] nt = entry.new_type if nt is None: - self.prompt = "(missingness for {0}) ".format(entry.name) + self.prompt = f"(missingness for {entry.name}) " else: - self.prompt = "(missingness for {0}: {1}) ".format(entry.name, nt.name) + self.prompt = f"(missingness for {entry.name}: {nt.name}) " else: self.prompt = "(missingness) " @@ -985,14 +993,12 @@ def _copy_entries(self) -> None: if entry.new_type is None or entry.new_type.name == "none": table.pop("missingness_generators", None) else: - src_stat_key = "missing_auto__{0}__0".format(entry.name) + src_stat_key = f"missing_auto__{entry.name}__0" table["missingness_generators"] = [ { "name": entry.new_type.name, "kwargs": { - "patterns": 'SRC_STATS["{0}"]["results"]'.format( - src_stat_key - ) + "patterns": f'SRC_STATS["{src_stat_key}"]["results"]' }, "columns": entry.new_type.columns, } @@ -1048,7 +1054,7 @@ def do_tables(self, _arg: str) -> None: for entry in self.table_entries: old = "-" if entry.old_type is None else entry.old_type.name new = "-" if entry.new_type is None else entry.new_type.name - desc = new if old == new else "{0}->{1}".format(old, new) + desc = new if old == new else f"{old}->{new}" self.print("{0} {1}", entry.name, desc) def do_next(self, arg: str) -> None: @@ -1066,7 +1072,7 @@ def do_next(self, arg: str) -> None: if index is None: self.print(self.ERROR_NO_SUCH_TABLE, arg) return - self.set_table_index(index) + self._set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) @@ -1080,7 +1086,7 @@ def complete_next( def do_previous(self, _arg: str) -> None: """Go to the previous table.""" - if not self.set_table_index(self.table_index - 1): + 
if not self._set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def _set_type(self, name: str, query: str, comment: str) -> None: @@ -1130,7 +1136,10 @@ def do_sampled(self, arg: str) -> None: count, self.get_nonnull_columns(entry.name), ), - f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}", + ( + "The missingness patterns and how often they appear in a" + f" sample of {count} from table {entry.name}" + ), ) self.print("Table {} set to sampled missingness", self.table_name()) self.next_table() @@ -1219,7 +1228,7 @@ class GeneratorCmd(DbCmd): ) def make_table_entry( - self, table_name: str, table: Mapping + self, table_name: str, table_config: Mapping ) -> GeneratorCmdTableEntry | None: """ Make a table entry. @@ -1228,11 +1237,11 @@ def make_table_entry( :param table: The portion of the ``config.yaml`` file describing this table. :return: The newly constructed table entry, or None if this table is to be ignored. 
""" - if table.get("ignore", False): + if table_config.get("ignore", False): return None - if table.get("vocabulary_table", False): + if table_config.get("vocabulary_table", False): return None - if table.get("num_rows_per_pass", 1) == 0: + if table_config.get("num_rows_per_pass", 1) == 0: return None metadata_table = self.metadata.tables[table_name] columns = [str(colname) for colname in metadata_table.columns.keys()] @@ -1241,7 +1250,7 @@ def make_table_entry( new_generator_infos: list[GeneratorInfo] = [] old_generator_infos: list[GeneratorInfo] = [] - for rg in table.get("row_generators", []): + for rg in table_config.get("row_generators", []): gen_name = rg.get("name", None) if gen_name: ca = rg.get("columns_assigned", []) @@ -1310,6 +1319,7 @@ def __init__( :param config: Configuration loaded from ``config.yaml`` """ super().__init__(src_dsn, src_schema, metadata, config) + self.generators: list[Generator] | None = None self.generator_index = 0 self.generators_valid_columns: Optional[tuple[int, list[str]]] = None self.set_prompt() @@ -1319,7 +1329,7 @@ def table_entries(self) -> list[GeneratorCmdTableEntry]: """Get the talbe entries, cast to ``GeneratorCmdTableEntry``.""" return cast(list[GeneratorCmdTableEntry], self._table_entries) - def find_entry_by_table_name( + def _find_entry_by_table_name( self, table_name: str ) -> GeneratorCmdTableEntry | None: """ @@ -1328,30 +1338,30 @@ def find_entry_by_table_name( :param table_name: The name of the table to find. :return: The table entry, or None if no such table name exists. """ - entry = super().find_entry_by_table_name(table_name) + entry = super()._find_entry_by_table_name(table_name) if entry is None: return None return cast(GeneratorCmdTableEntry, entry) - def set_table_index(self, index: int) -> bool: + def _set_table_index(self, index: int) -> bool: """ Move to a new table. :param index: table index to move to. 
""" - ret = super().set_table_index(index) + ret = super()._set_table_index(index) if ret: self.generator_index = 0 self.set_prompt() return ret - def previous_table(self) -> bool: + def _previous_table(self) -> bool: """ Move to the table before the current one. :return: True if there is a previous table to go to. """ - ret = self.set_table_index(self.table_index - 1) + ret = self._set_table_index(self.table_index - 1) if ret: table = self.get_table() if table is None: @@ -1371,7 +1381,7 @@ def get_table(self) -> GeneratorCmdTableEntry | None: return self.table_entries[self.table_index] return None - def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: + def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: """Get a pair; the table name then the generator information.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] @@ -1380,26 +1390,26 @@ def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: return (entry.name, None) return (None, None) - def get_column_names(self) -> list[str]: + def _get_column_names(self) -> list[str]: """Get the (unqualified) names for all the current columns.""" - (_, generator_info) = self.get_table_and_generator() + (_, generator_info) = self._get_table_and_generator() return generator_info.columns if generator_info else [] - def column_metadata(self) -> list[Column]: + def _column_metadata(self) -> list[Column]: """Get the metadata for all the current columns.""" table = self.table_metadata() if table is None: return [] - return [table.columns[name] for name in self.get_column_names()] + return [table.columns[name] for name in self._get_column_names()] def set_prompt(self) -> None: """Set the prompt according to the current table, column and generator.""" - (table_name, gen_info) = self.get_table_and_generator() + (table_name, gen_info) = self._get_table_and_generator() if table_name is None: self.prompt = "(generators) 
" return if gen_info is None: - self.prompt = "({table}) ".format(table=table_name) + self.prompt = f"({table_name}) " return table = self.table_metadata() columns = [ @@ -1408,7 +1418,7 @@ def set_prompt(self) -> None: gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - def _remove_auto_src_stats(self) -> list[dict[str, Any]]: + def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]: """ Remove all automatic source stats. @@ -1510,7 +1520,7 @@ def do_quit(self, arg: str) -> bool: return True return False - def do_tables(self, arg: str) -> None: + def do_tables(self, _arg: str) -> None: """List the tables.""" for t_entry in self.table_entries: entry = cast(GeneratorCmdTableEntry, t_entry) @@ -1518,7 +1528,7 @@ def do_tables(self, arg: str) -> None: how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" self.print("{0} ({1})", entry.name, how_many) - def do_list(self, arg: str) -> None: + def do_list(self, _arg: str) -> None: """List the generators in the current table.""" if len(self.table_entries) <= self.table_index: self.print("Error: no table {0}", self.table_index) @@ -1547,7 +1557,7 @@ def do_columns(self, _arg: str) -> None: def do_info(self, _arg: str) -> None: """Show information about the current column.""" - for cm in self.column_metadata(): + for cm in self._column_metadata(): self.print( "Column {0} in table {1} has type {2} ({3}).", cm.name, @@ -1614,7 +1624,7 @@ def go_to(self, target: str) -> bool: if gen_index is None: self.print("we cannot set the generator for column {0}", parts[1]) return False - self.set_table_index(table_index) + self._set_table_index(table_index) if gen_index is not None: self.generator_index = gen_index self.set_prompt() @@ -1694,7 +1704,7 @@ def complete_next( def do_previous(self, _arg: str) -> None: """Go to the previous generator.""" if self.generator_index == 0: - self.previous_table() + self._previous_table() else: 
self.generator_index -= 1 self.set_prompt() @@ -1707,7 +1717,7 @@ def _generators_valid(self) -> bool: """Test if ``self.generators`` is still correct for the current columns.""" return self.generators_valid_columns == ( self.table_index, - self.get_column_names(), + self._get_column_names(), ) def _get_generator_proposals(self) -> list[Generator]: @@ -1715,13 +1725,13 @@ def _get_generator_proposals(self) -> list[Generator]: if not self._generators_valid(): self.generators = None if self.generators is None: - columns = self.column_metadata() + columns = self._column_metadata() gens = everything_factory().get_generators(columns, self.sync_engine) sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) self.generators = sorted_gens self.generators_valid_columns = ( self.table_index, - self.get_column_names().copy(), + self._get_column_names().copy(), ) return self.generators @@ -1762,7 +1772,7 @@ def do_compare(self, arg: str) -> None: for argument in args: if argument.isdigit(): n = int(argument) - if 0 < n and n <= len(gens): + if 0 < n <= len(gens): gen = gens[n - 1] comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit) self._print_values_queried(table_name, n, gen) @@ -1822,7 +1832,7 @@ def _print_custom_queries(self, gen: Generator) -> None: def _get_custom_queries_from( self, out: dict[str, Any], nominal: Any, actual: Any ) -> None: - if type(nominal) is str: + if isinstance(nominal, str): src_stat_groups = self.SRC_STAT_RE.search(nominal) # Do we have a SRC_STAT reference? 
if src_stat_groups: @@ -1834,10 +1844,10 @@ def _get_custom_queries_from( actual = {sub: actual} else: out[cq_key] = actual - elif type(nominal) is list and type(actual) is list: + elif isinstance(nominal, Sequence) and isinstance(actual, Sequence): for i in range(min(len(nominal), len(actual))): self._get_custom_queries_from(out, nominal[i], actual[i]) - elif type(nominal) is dict and type(actual) is dict: + elif isinstance(nominal, Mapping) and isinstance(actual, Mapping): for k, v in nominal.items(): if k in actual: self._get_custom_queries_from(out, v, actual[k]) @@ -1892,12 +1902,12 @@ def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None def _get_column_data( self, count: int, to_str: Callable[[Any], str] = repr ) -> list[list[str]]: - columns = self.get_column_names() + columns = self._get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) with self.sync_engine.connect() as connection: result = connection.execute( - text( + sqlalchemy.text( f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" ) ) @@ -1977,7 +1987,7 @@ def do_set(self, arg: str) -> None: def set_generator(self, gen: Generator | None) -> None: """Set the current column's generator.""" - (table, gen_info) = self.get_table_and_generator() + (table, gen_info) = self._get_table_and_generator() if table is None: self.print("Error: no table") return @@ -2155,7 +2165,7 @@ def update_config_generators( if line: if len(line) != 3: logger.error( - "line {0} of file {1} does not have three values", + "line %d of file %s does not have three values", line_no, spec_path, ) diff --git a/datafaker/main.py b/datafaker/main.py index c1299830..cf7bd3bf 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -303,7 +303,6 @@ def make_vocab( @app.command() def make_stats( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = 
Option(CONFIG_FILENAME, help="The configuration file"), stats_file: str = Option(STATS_FILENAME), force: bool = Option( @@ -324,13 +323,12 @@ def make_stats( _check_file_non_existence(stats_file_path) config = read_config_file(config_file) if config_file is not None else {} - orm_metadata = load_metadata(orm_file, config) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(src_dsn, config, orm_metadata, settings.src_schema) + make_src_stats(src_dsn, config, settings.src_schema) ) stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8") logger.debug("%s created.", stats_file) diff --git a/datafaker/make.py b/datafaker/make.py index 096ee8bf..bca7a2b7 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -139,15 +139,15 @@ class StoryGeneratorInfo: def _render_value(v: Any) -> str: - if type(v) is list: + if isinstance(v, list): return "[" + ", ".join(_render_value(x) for x in v) + "]" - if type(v) is set: + if isinstance(v, set): return "{" + ", ".join(_render_value(x) for x in v) + "}" - if type(v) is dict: + if isinstance(v, dict): return ( "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" ) - if type(v) is str: + if isinstance(v, str): return v return str(v) @@ -603,8 +603,10 @@ def make_table_generators( # pylint: disable=too-many-locals Args: metadata: database ORM config: Configuration to control the generator creation. - orm_filename: "orm.yaml" file path so that the generator file can load the MetaData object - config_filename: "config.yaml" file path so that the generator file can load the MetaData object + orm_filename: "orm.yaml" file path so that the generator + file can load the MetaData object + config_filename: "config.yaml" file path so that the generator + file can load the MetaData object src_stats_filename: A filename for where to read src stats from. 
Optional, if `None` this feature will be skipped overwrite_files: Whether to overwrite pre-existing vocabulary files @@ -765,7 +767,8 @@ async def __aexit__( """Exit the ``with`` section, closing the connection.""" if isinstance(self._connection, AsyncConnection): await self._connection.close() - self._connection.close() + else: + self._connection.close() async def execute_raw_query(self, query: Executable) -> CursorResult: """Execute the query on the owned connection.""" @@ -808,7 +811,7 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any: def fix_type(value: Any) -> Any: """Make this value suitable for yaml output.""" - if type(value) is decimal.Decimal: + if isinstance(value, decimal.Decimal): return float(value) return value @@ -819,37 +822,34 @@ def fix_types(dics: list[dict]) -> list[dict]: async def make_src_stats( - dsn: str, config: Mapping, metadata: MetaData, schema_name: Optional[str] = None + dsn: str, config: Mapping, schema_name: Optional[str] = None ) -> dict[str, dict[str, Any]]: - """Run the src-stats queries specified by the configuration. + """ + Run the src-stats queries specified by the configuration. Query the src database with the queries in the src-stats block of the `config` dictionary, using the differential privacy parameters set in the `smartnoise-sql` block of `config`. Record the results in a dictionary and return it. - Args: - dsn: database connection string - config: a dictionary with the necessary configuration - metadata: the database ORM - schema_name: name of the database schema - Returns: - The dictionary of src-stats. + :param dsn: database connection string + :param config: a dictionary with the necessary configuration + :param schema_name: name of the database schema + :return: The dictionary of src-stats. 
""" use_asyncio = config.get("use-asyncio", False) engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio) async with DbConnection(engine) as db_conn: - return await make_src_stats_connection(config, db_conn, metadata) + return await make_src_stats_connection(config, db_conn) async def make_src_stats_connection( - config: Mapping, db_conn: DbConnection, metadata: MetaData + config: Mapping, db_conn: DbConnection ) -> dict[str, dict[str, Any]]: """ Make the ``src-stats.yaml`` file given the database connection to read from. :param config: configuration from ``config.yaml``. :param db_conn: Source database connection. - :param metadata: Source database metadata from ``orm.yaml``. """ date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) diff --git a/datafaker/utils.py b/datafaker/utils.py index 61089ab3..6d0041de 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -6,19 +6,10 @@ import json import logging import sys +from collections.abc import Mapping, Sequence from pathlib import Path from types import ModuleType -from typing import ( - Any, - Callable, - Final, - Generator, - Iterable, - Mapping, - Optional, - TypeVar, - Union, -) +from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union import sqlalchemy import yaml @@ -119,6 +110,7 @@ def table_row_count(table: Table, conn: Connection) -> int: :return: The number of rows in the table. 
""" return conn.execute( + # pylint: disable=not-callable select(sqlalchemy.func.count()).select_from( sqlalchemy.table( table.name, @@ -527,7 +519,7 @@ def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any, None, None]: if not line or line.startswith("-"): if buf: yl = yaml.load(buf, yaml.Loader) - assert type(yl) is list and len(yl) == 1 + assert isinstance(yl, Sequence) and len(yl) == 1 yield yl[0] if not line: return @@ -602,7 +594,7 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: - logger.warning(f"Cycle detected between tables: {cycle}") + logger.warning("Cycle detected between tables: %s", cycle) return [metadata.tables[tn] for tn in sorted_tables] @@ -652,7 +644,7 @@ def generators_require_stats(config: Mapping) -> bool: names = ( node.id for node in ast.walk(ast.parse(arg)) - if type(node) is ast.Name + if isinstance(node, ast.Name) ) if any(name == "SRC_STATS" for name in names): stats_required = True @@ -668,12 +660,12 @@ def generators_require_stats(config: Mapping) -> bool: ) ) for k, arg in call.get("kwargs", {}).items(): - if type(arg) is str: + if isinstance(arg, str): try: names = ( node.id for node in ast.walk(ast.parse(arg)) - if type(node) is ast.Name + if isinstance(node, ast.Name) ) if any(name == "SRC_STATS" for name in names): stats_required = True diff --git a/tests/test_functional.py b/tests/test_functional.py index 792dc40a..ac7e51a7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -81,6 +81,13 @@ def tearDown(self) -> None: os.chdir(self.start_dir) super().tearDown() + def assert_silent_success(self, completed_process: Result) -> None: + """Assert that the process completed successfully without producing output.""" + self.assertNoException(completed_process) + self.assertSuccess(completed_process) + self.assertEqual(completed_process.stderr, "") + 
self.assertEqual(completed_process.stdout, "") + def test_workflow_minimal_args(self) -> None: """Test the recommended CLI workflow runs without errors.""" shutil.copy(self.config_file_path, "config.yaml") @@ -88,26 +95,19 @@ def test_workflow_minimal_args(self) -> None: "make-tables", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stderr, "") - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-vocab", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-stats", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "create-generators", @@ -138,27 +138,18 @@ def test_workflow_minimal_args(self) -> None: completed_process = self.invoke( "create-tables", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) completed_process = self.invoke( "create-vocab", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-stats", "--force", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) 
completed_process = self.invoke("create-data") self.assertNoException(completed_process) @@ -514,7 +505,6 @@ def test_unique_constraint_fail(self) -> None: "make-stats", f"--stats-file={self.stats_file_path}", f"--config-file={self.config_file_path}", - f"--orm-file={self.alt_orm_file_path}", "--force", ) self.invoke( diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 1082c337..fcb5ced2 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -4,9 +4,10 @@ import re from dataclasses import dataclass from typing import Any, Iterable, MutableMapping +from unittest import TestCase from unittest.mock import MagicMock, Mock, patch -from sqlalchemy import insert, select +from sqlalchemy import Connection, MetaData, insert, select from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.interactive import ( @@ -20,6 +21,8 @@ class TestDbCmdMixin(DbCmd): + """A mixin for capturing output from interactive commands.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize a TestDbCmdMixin""" super().__init__(*args, **kwargs) @@ -46,10 +49,11 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: """Capture the printed table.""" self.columns = columns - def columnize(self, list: list[str] | None, _displaywidth: int = 80) -> None: + # pylint: disable=arguments-renamed + def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: """Capture the printed table.""" - if list is not None: - self.column_items.append(list) + if items is not None: + self.column_items.append(items) def ask_save(self) -> str: """Quitting always works without needing to ask the user.""" @@ -666,11 +670,11 @@ def test_weighted_choice_generator_generates_choices(self) -> None: gc.do_propose("") proposals = gc.get_proposals() gen_proposal = proposals[generator] - self.assertSubset(set(gen_proposal[2]), {str(v) for v in values}) + self.assert_subset(set(gen_proposal[2]), {str(v) for v 
in values}) gc.do_compare(str(gen_proposal[0])) col_heading = f"{gen_proposal[0]}. {generator}" self.assertIn(col_heading, gc.columns) - self.assertSubset(set(gc.columns[col_heading]), values) + self.assert_subset(set(gc.columns[col_heading]), values) def test_merge_columns(self) -> None: """Test that we can merge columns and set a multivariate generator""" @@ -822,21 +826,16 @@ def test_aggregate_queries_merge(self) -> None: Test that we can set a generator that requires select aggregate clauses and keep an old one, resulting in a merged query. """ - config = { - "tables": { - "string": { - "row_generators": [ - { - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }, - } - ] - } + rg = { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', }, + } + config = { + "tables": {"string": {"row_generators": [rg]}}, "src-stats": [ { "name": "auto__string", @@ -1096,7 +1095,6 @@ def test_create_with_choice(self) -> None: def test_create_with_weighted_choice(self) -> None: """Smoke test weighted choice.""" - table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") gc.reset() @@ -1108,7 +1106,7 @@ def test_create_with_weighted_choice(self) -> None: "dist_gen.weighted_choice [sampled and suppressed]", proposals ) prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] - self.assertSubset(set(prop[2]), {"1", "4"}) + self.assert_subset(set(prop[2]), {"1", "4"}) gc.reset() gc.do_compare(str(prop[0])) col_heading = ( @@ -1116,7 +1114,7 @@ def test_create_with_weighted_choice(self) -> None: ) self.assertIn(col_heading, set(gc.columns.keys())) col_set: set[int] = set(gc.columns[col_heading]) - 
self.assertSubset(col_set, {1, 4}) + self.assert_subset(col_set, {1, 4}) gc.do_set(str(prop[0])) gc.do_next("number_table.two") gc.reset() @@ -1128,13 +1126,13 @@ def test_create_with_weighted_choice(self) -> None: "dist_gen.weighted_choice [sampled and suppressed]", proposals ) prop = proposals["dist_gen.weighted_choice"] - self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) + self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.reset() gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. dist_gen.weighted_choice" self.assertIn(col_heading, set(gc.columns.keys())) col_set2: set[int] = set(gc.columns[col_heading]) - self.assertSubset(col_set2, {1, 2, 3, 4, 5}) + self.assert_subset(col_set2, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_next("number_table.three") gc.reset() @@ -1146,22 +1144,22 @@ def test_create_with_weighted_choice(self) -> None: "dist_gen.weighted_choice [sampled and suppressed]", proposals ) prop = proposals["dist_gen.weighted_choice [sampled]"] - self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) + self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. 
dist_gen.weighted_choice [sampled]" self.assertIn(col_heading, set(gc.columns.keys())) col_set3: set[int] = set(gc.columns[col_heading]) - self.assertSubset(col_set3, {1, 2, 3, 4, 5}) + self.assert_subset(col_set3, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() ones = set() twos = set() threes = set() - for row in rows: + for row in conn.execute( + select(self.metadata.tables["number_table"]) + ).fetchall(): ones.add(row.one) twos.add(row.two) threes.add(row.three) @@ -1447,6 +1445,60 @@ def covar(self) -> float: return (self.xy - self.x * self.y / self.n) / (self.n - 1) +class EavMeasurementTableStats: + """The statistics for the Measurement table of eav.sql.""" + + def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None: + stmt = select(metadata.tables["measurement"]) + rows = conn.execute(stmt).fetchall() + self.types: set[int] = set() + self.one_count = 0 + self.one_yes_count = 0 + self.two = Correlation() + self.three = Correlation() + self.four = Correlation() + self.fish = Stat() + self.fowl = Stat() + for row in rows: + self.types.add(row.type) + if row.type == 1: + # yes or no + test.assertIsNone(row.first_value) + test.assertIsNone(row.second_value) + test.assertIn(row.third_value, {"yes", "no"}) + self.one_count += 1 + if row.third_value == "yes": + self.one_yes_count += 1 + elif row.type == 2: + # positive correlation around 1.4, 1.8 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.two.add2(row.first_value, row.second_value) + elif row.type == 3: + # negative correlation around 11.8, 12.1 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.three.add2(row.first_value, row.second_value) + elif row.type 
== 4: + # positive correlation around 21.4, 23.4 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.four.add2(row.first_value, row.second_value) + elif row.type == 5: + test.assertIn(row.third_value, {"fish", "fowl"}) + test.assertIsNotNone(row.first_value) + test.assertIsNone(row.second_value) + if row.third_value == "fish": + # mean 8.1 and sd 0.755 + self.fish.add(row.first_value) + else: + # mean 11.2 and sd 1.114 + self.fowl.add(row.first_value) + + class NullPartitionedTests(GeneratesDBTestCase): """Testing null-partitioned grouped multivariate generation.""" @@ -1466,7 +1518,6 @@ def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: def test_create_with_null_partitioned_grouped_multivariate(self) -> None: """Test EAV for all columns.""" - table_name = "measurement" generate_count = 800 with self._get_cmd({}) as gc: gc.do_next("measurement.type") @@ -1502,96 +1553,49 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: conn.commit() self.create_data(gc.config, num_passes=generate_count) with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - one_count = 0 - one_yes_count = 0 - two = Correlation() - three = Correlation() - four = Correlation() - fish = Stat() - fowl = Stat() - for row in rows: - if row.type == 1: - # yes or no - self.assertIsNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {"yes", "no"}) - one_count += 1 - if row.third_value == "yes": - one_yes_count += 1 - elif row.type == 2: - # positive correlation around 1.4, 1.8 - self.assertIsNotNone(row.first_value) - self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - two.add2(row.first_value, row.second_value) - elif row.type == 3: - # negative correlation around 11.8, 12.1 - self.assertIsNotNone(row.first_value) - 
self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - three.add2(row.first_value, row.second_value) - elif row.type == 4: - # positive correlation around 21.4, 23.4 - self.assertIsNotNone(row.first_value) - self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - four.add2(row.first_value, row.second_value) - elif row.type == 5: - self.assertIn(row.third_value, {"fish", "fowl"}) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - if row.third_value == "fish": - # mean 8.1 and sd 0.755 - fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - fowl.add(row.first_value) - # type 1 - self.assertAlmostEqual( - one_count, generate_count * 5 / 20, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - one_yes_count / one_count, 0.4, delta=generate_count * 0.4 - ) - # type 2 - self.assertAlmostEqual( - two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 - ) - self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) - self.assertAlmostEqual(two.x_var(), 0.315, delta=0.18) - self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(two.y_var(), 0.105, delta=0.06) - self.assertAlmostEqual(two.covar(), 0.105, delta=0.07) - # type 3 - self.assertAlmostEqual( - three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(three.covar(), -2.085, delta=1.1) - # type 4 - self.assertAlmostEqual( - four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(four.covar(), 3.33, delta=1) - # type 5/fish - self.assertAlmostEqual( - fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.6) - # type 5/fowl - self.assertAlmostEqual( - fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - 
self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) + stats = EavMeasurementTableStats(conn, self.metadata, self) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 2 + self.assertAlmostEqual( + stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) + self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.6) + self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) + self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) + # type 3 + self.assertAlmostEqual( + stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) + # type 4 + self.assertAlmostEqual( + stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: """Test EAV for all columns with sampled and suppressed generation.""" - table_name = "measurement" - table2_name = "observation" generate_count = 800 with self._get_cmd({}) as gc: gc.do_next("measurement.type") @@ -1646,61 +1650,12 @@ def 
test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No conn.commit() self.create_data(gc.config, num_passes=generate_count) with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - one_count = 0 - one_yes_count = 0 - fish = Stat() - fowl = Stat() - types: set[int] = set() - for row in rows: - types.add(row.type) - if row.type == 1: - # yes or no - self.assertIsNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {"yes", "no"}) - if row.third_value == "yes": - one_yes_count += 1 - one_count += 1 - elif row.type == 5: - self.assertIn(row.third_value, {"fish", "fowl"}) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - if row.third_value == "fish": - # mean 8.1 and sd 0.755 - fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - fowl.add(row.first_value) - self.assertSubset(types, {1, 2, 3, 4, 5}) - self.assertEqual(len(types), 4) - self.assertSubset({1, 5}, types) - # type 1 - self.assertAlmostEqual( - one_count, generate_count * 5 / 11, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - one_yes_count / one_count, 0.4, delta=generate_count * 0.4 - ) - # type 5/fish - self.assertAlmostEqual( - fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.5) - # type 5/fowl - self.assertAlmostEqual( - fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) - stmt = select(self.metadata.tables[table2_name]) + stats = EavMeasurementTableStats(conn, self.metadata, self) + stmt = select(self.metadata.tables["observation"]) rows = conn.execute(stmt).fetchall() firsts = Stat() for row in rows: - types.add(row.type) + 
stats.types.add(row.type) self.assertEqual(row.type, 1) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) @@ -1708,6 +1663,29 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No firsts.add(row.first_value) self.assertEqual(firsts.count(), 800) self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) + self.assert_subset(stats.types, {1, 2, 3, 4, 5}) + self.assertEqual(len(stats.types), 4) + self.assert_subset({1, 5}, stats.types) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) class NonInteractiveTests(RequiresDBTestCase): diff --git a/tests/test_main.py b/tests/test_main.py index a570fe4d..e318f1e5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -48,6 +48,7 @@ def test_create_vocab( @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") + # pylint: disable=too-many-positional-arguments,too-many-arguments def test_create_generators( self, mock_require_stats: MagicMock, @@ -89,6 +90,7 @@ def test_create_generators( @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") + # pylint: disable=too-many-positional-arguments,too-many-arguments def 
test_create_generators_uses_default_stats_file_if_necessary( self, mock_require_stats: MagicMock, @@ -151,6 +153,7 @@ def test_create_generators_errors_if_file_exists( @patch("datafaker.main.get_settings") @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") + # pylint: disable=too-many-positional-arguments,too-many-arguments def test_create_generators_with_force_enabled( self, mock_make: MagicMock, @@ -371,10 +374,8 @@ def test_make_tables_with_force_enabled( @patch("datafaker.main.Path") @patch("datafaker.main.make_src_stats") @patch("datafaker.main.get_settings") - @patch("datafaker.main.load_metadata", side_effect=["ms"]) def test_make_stats( self, - _lm: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock, @@ -398,7 +399,7 @@ def test_make_stats( with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) mock_make.assert_called_once_with( - get_test_settings().src_dsn, config, "ms", None + get_test_settings().src_dsn, config, None ) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" @@ -452,12 +453,10 @@ def test_make_stats_errors_if_no_src_dsn(self, mock_logger: MagicMock) -> None: @patch("datafaker.main.Path") @patch("datafaker.main.make_src_stats") @patch("datafaker.main.get_settings") - @patch("datafaker.main.load_metadata") def test_make_stats_with_force_enabled( self, - mock_meta: MagicMock, mock_get_settings: MagicMock, - mock_make: MagicMock, + mock_make_src_stats: MagicMock, mock_path: MagicMock, ) -> None: """Tests that the make-stats command overwrite files when instructed.""" @@ -469,7 +468,7 @@ def test_make_stats_with_force_enabled( test_settings: Settings = get_test_settings() mock_get_settings.return_value = test_settings make_test_output: dict = {"some_stat": 0} - mock_make.return_value = make_test_output + mock_make_src_stats.return_value = make_test_output for force_option in ["--force", "-f"]: with self.subTest(f"Using option 
{force_option}"): @@ -479,23 +478,21 @@ def test_make_stats_with_force_enabled( "make-stats", "--stats-file=stats_file.yaml", f"--config-file={test_config_file}", - "--orm-file=tests/examples/example_config.yaml", force_option, ], ) - mock_make.assert_called_once_with( + mock_make_src_stats.assert_called_once_with( test_settings.src_dsn, config_file_content, - mock_meta.return_value, - None, + test_settings.src_schema, ) mock_path.return_value.write_text.assert_called_once_with( "some_stat: 0\n", encoding="utf-8" ) self.assertSuccess(result) - mock_make.reset_mock() + mock_make_src_stats.reset_mock() mock_path.reset_mock() def test_validate_config(self) -> None: diff --git a/tests/test_make.py b/tests/test_make.py index b522778f..49bb9e71 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -170,14 +170,14 @@ def check_make_stats_output(self, src_stats: dict) -> None: def test_make_stats_no_asyncio_schema(self) -> None: """Test that make_src_stats works when explicitly naming a schema.""" src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, self.config, self.metadata, self.schema_name) + make_src_stats(self.dsn, self.config, self.schema_name) ) self.check_make_stats_output(src_stats) def test_make_stats_no_asyncio(self) -> None: """Test that make_src_stats works using the example configuration.""" src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, self.config, self.metadata, self.schema_name) + make_src_stats(self.dsn, self.config, self.schema_name) ) self.check_make_stats_output(src_stats) @@ -189,7 +189,7 @@ def test_make_stats_asyncio(self) -> None: asyncio.set_event_loop(loop) config_asyncio = {**self.config, "use-asyncio": True} src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, config_asyncio, self.metadata, self.schema_name) + make_src_stats(self.dsn, config_asyncio, self.schema_name) ) self.check_make_stats_output(src_stats) @@ -220,7 +220,7 @@ def 
test_make_stats_empty_result(self, mock_logger: MagicMock) -> None: ] } src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, config, self.metadata, self.schema_name) + make_src_stats(self.dsn, config, self.schema_name) ) self.assertEqual(src_stats[query_name1]["results"], []) self.assertEqual(src_stats[query_name2]["results"], []) diff --git a/tests/test_providers.py b/tests/test_providers.py index b5437833..cd880072 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,9 +1,8 @@ """Tests for the providers module.""" import datetime as dt -from pathlib import Path from typing import Any -from sqlalchemy import Column, Integer, Text, create_engine, insert +from sqlalchemy import Column, Integer, Text, insert from sqlalchemy.ext.declarative import declarative_base from datafaker import providers diff --git a/tests/test_remove.py b/tests/test_remove.py index a6dbb85c..0d466db7 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -20,6 +20,7 @@ class RemoveThingsTestCase(RequiresDBTestCase): def count_rows(self, connection: Connection, table_name: str) -> int | None: """Count the rows in a table.""" return connection.execute( + # pylint: disable=not-callable. 
select(func.count()).select_from(self.metadata.tables[table_name]) ).scalar() @@ -40,8 +41,8 @@ def test_remove_data(self, mock_get_settings: MagicMock) -> None: }, ) with self.sync_engine.connect() as conn: - self.assertGreaterAndNotNone(self.count_rows(conn, "manufacturer"), 0) - self.assertGreaterAndNotNone(self.count_rows(conn, "model"), 0) + self.assert_greater_and_not_none(self.count_rows(conn, "manufacturer"), 0) + self.assert_greater_and_not_none(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) self.assertEqual(self.count_rows(conn, "string"), 0) self.assertEqual(self.count_rows(conn, "signature_model"), 0) diff --git a/tests/test_rst.py b/tests/test_rst.py index 29bf971e..1a57ed61 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -55,7 +55,7 @@ def test_dir(self) -> None: for file_error in file_errors # Only worry about ERRORs and WARNINGs if file_error.level <= 2 - if not any(filter(lambda m: m in file_error.full_message, allowed_errors)) + if not any(m in file_error.full_message for m in allowed_errors) ] if filtered_errors: diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 503a36f5..afec078c 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -1,5 +1,4 @@ """Tests for the unique_generator module.""" -from pathlib import Path from unittest.mock import MagicMock from sqlalchemy import Boolean, Column, Integer, Text, UniqueConstraint, insert diff --git a/tests/utils.py b/tests/utils.py index 4e9f2365..78df87cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,7 +12,6 @@ import testing.postgresql import yaml -from sqlalchemy import MetaData from sqlalchemy.schema import MetaData from datafaker import settings @@ -80,7 +79,7 @@ def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-nam return self.fail("".join(traceback.format_exception(result.exception))) - def assertGreaterAndNotNone(self, left: float | None, right: 
float) -> None: + def assert_greater_and_not_none(self, left: float | None, right: float) -> None: """ Assert left is not None and greater than right """ @@ -89,7 +88,7 @@ def assertGreaterAndNotNone(self, left: float | None, right: float) -> None: else: self.assertGreater(left, right) - def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None: + def assert_subset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None: """Assert a set is a (non-strict) subset. :param set1: The asserted subset. @@ -100,9 +99,9 @@ def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> No try: difference = set1.difference(set2) except TypeError as e: - self.fail("invalid type when attempting set difference: %s" % e) + self.fail(f"invalid type when attempting set difference: {e}") except AttributeError as e: - self.fail("first argument does not support set difference: %s" % e) + self.fail(f"first argument does not support set difference: {e}") if not difference: return @@ -113,8 +112,8 @@ def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> No for item in difference: lines.append(repr(item)) - standardMsg = "\n".join(lines) - self.fail(self._formatMessage(msg, standardMsg)) + standard_msg = "\n".join(lines) + self.fail(self._formatMessage(msg, standard_msg)) @skipUnless(shutil.which("psql"), "need to find 'psql': install PostgreSQL to enable") @@ -148,7 +147,7 @@ def tearDownClass(cls) -> None: def setUp(self) -> None: super().setUp() assert self.Postgresql is not None - self.postgresql = self.Postgresql() + self.postgresql = self.Postgresql() # pylint: disable=not-callable if self.dump_file_path is not None: self.run_psql(Path(self.examples_dir) / Path(self.dump_file_path)) self.engine = create_db_engine( @@ -166,11 +165,12 @@ def tearDown(self) -> None: @property def dsn(self) -> str: + """Get the database connection string.""" if self.database_name: url = 
self.postgresql.url(database=self.database_name) else: url = self.postgresql.url() - assert type(url) is str + assert isinstance(url, str) return url def run_psql(self, dump_file: Path) -> None: @@ -199,7 +199,19 @@ def run_psql(self, dump_file: Path) -> None: class GeneratesDBTestCase(RequiresDBTestCase): + """A test case for which a database is generated.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialise a GeneratedDB test case.""" + super().__init__(*args, **kwargs) + self.generators_file_path = "" + self.stats_fd = 0 + self.stats_file_path = "" + self.config_file_path = "" + self.config_fd = 0 + def setUp(self) -> None: + """Set up the test case with an actual orm.yaml file.""" super().setUp() # Generate the `orm.yaml` from the database (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True) @@ -207,21 +219,20 @@ def setUp(self) -> None: orm_fh.write(make_tables_file(self.dsn, self.schema_name, {})) def set_configuration(self, config: Mapping[str, Any]) -> None: - """ - Accepts a configuration file, writes it out. - """ + """Accepts a configuration file, writes it out.""" (self.config_fd, self.config_file_path) = mkstemp(".yaml", "config_", text=True) with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: """ - Runs `make-stats` producing `src-stats.yaml` + Runs `make-stats` producing `src-stats.yaml`. + :return: Python dictionary representation of the contents of the src-stats file """ loop = asyncio.new_event_loop() src_stats = loop.run_until_complete( - make_src_stats(self.dsn, config, self.metadata, self.schema_name) + make_src_stats(self.dsn, config, self.schema_name) ) loop.close() (self.stats_fd, self.stats_file_path) = mkstemp( From 2894044f660f2cdc01ce9a735b47ff20b4ad6f80 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 13 Oct 2025 18:56:46 +0100 Subject: [PATCH 20/44] Precommit clean! 
--- .pre-commit-config.yaml | 1 + datafaker/base.py | 53 +- datafaker/create.py | 19 +- datafaker/generators.py | 2160 ---------------- datafaker/generators/__init__.py | 53 + datafaker/generators/base.py | 417 ++++ datafaker/generators/choice.py | 398 +++ datafaker/generators/continuous.py | 471 ++++ datafaker/generators/mimesis.py | 418 ++++ datafaker/generators/partitioned.py | 514 ++++ datafaker/interactive.py | 2175 ----------------- datafaker/interactive/__init__.py | 95 + datafaker/interactive/base.py | 404 +++ datafaker/interactive/generators.py | 980 ++++++++ datafaker/interactive/missingness.py | 355 +++ datafaker/interactive/table.py | 376 +++ datafaker/main.py | 5 +- datafaker/utils.py | 11 +- ...tive.py => test_interactive_generators.py} | 539 +--- tests/test_interactive_missingness.py | 100 + tests/test_interactive_table.py | 398 +++ tests/test_main.py | 4 +- tests/utils.py | 49 +- 23 files changed, 5090 insertions(+), 4905 deletions(-) delete mode 100644 datafaker/generators.py create mode 100644 datafaker/generators/__init__.py create mode 100644 datafaker/generators/base.py create mode 100644 datafaker/generators/choice.py create mode 100644 datafaker/generators/continuous.py create mode 100644 datafaker/generators/mimesis.py create mode 100644 datafaker/generators/partitioned.py delete mode 100644 datafaker/interactive.py create mode 100644 datafaker/interactive/__init__.py create mode 100644 datafaker/interactive/base.py create mode 100644 datafaker/interactive/generators.py create mode 100644 datafaker/interactive/missingness.py create mode 100644 datafaker/interactive/table.py rename tests/{test_interactive.py => test_interactive_generators.py} (70%) create mode 100644 tests/test_interactive_missingness.py create mode 100644 tests/test_interactive_table.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7eba811b..04464f9d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,7 @@ repos: rev: v4.2.0 
hooks: - id: trailing-whitespace + exclude: docs/(source|build/html)/_static/ - id: end-of-file-fixer exclude: docs/source/_static/ - id: check-yaml diff --git a/datafaker/base.py b/datafaker/base.py index 6ff1890e..f75591c3 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -26,6 +26,10 @@ ) +class InappropriateGeneratorException(Exception): + """Exception thrown if a generator is requested that is not appropriate.""" + + @functools.cache def zipf_weights(size: int) -> list[float]: """Get the weights of a Zipf distribution of a given size.""" @@ -122,19 +126,26 @@ def lognormal(self, logmean: float, logsd: float) -> float: """ return random.lognormvariate(float(logmean), float(logsd)) - def choice(self, a: list[T]) -> T: + def choice_direct(self, a: list[T]) -> T: + """ + Choose a value with equal probability. + + :param a: The list of values to output. + :return: The chosen value. + """ + return random.choice(a) + + def choice(self, a: list[Mapping[str, T]]) -> T | None: """ Choose a value with equal probability. - :param a: The list of values to output. Each element is either - the value itself, or a mapping with a key ``value`` and the key - is the value to return. + :param a: The list of values to output. Each element is a mapping with + a key ``value`` and the key is the value to return. :return: The chosen value. """ - c = random.choice(a) - return c["value"] if isinstance(c, Mapping) and "value" in c else c + return self.choice_direct(a).get("value", None) - def zipf_choice(self, a: list[T], n: int | None = None) -> T: + def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T: """ Choose a value according to the Zipf distribution. @@ -142,14 +153,26 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T: 1/n times as frequently as the first value is chosen. :param a: The list of values to output, most frequent first. - Each element is either the value itself, or a mapping with - a key ``value`` and the key is the value to return. 
:return: The chosen value. """ if n is None: n = len(a) - c = random.choices(a, weights=zipf_weights(n))[0] - return c["value"] if isinstance(c, Mapping) and "value" in c else c + return random.choices(a, weights=zipf_weights(n))[0] + + def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None: + """ + Choose a value according to the Zipf distribution. + + The nth value (starting from 1) is chosen with a frequency + 1/n times as frequently as the first value is chosen. + + :param a: The list of rows to choose between, most frequent first. + Each element is a mapping with a key ``value`` and the key is the + value to return. + :return: The chosen value. + """ + c = self.zipf_choice_direct(a, n) + return c.get("value", None) def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ @@ -214,7 +237,9 @@ def _select_group(self, alts: list[dict[str, Any]]) -> Any: choice -= alt["count"] if choice < 0: return alt - raise Exception("Internal error: ran out of choices in _select_group") + raise NothingToGenerateException( + "Internal error: ran out of choices in _select_group" + ) def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: """ @@ -286,7 +311,9 @@ def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any def _check_generator_name(self, name: str) -> None: if name not in self.PERMITTED_SUBGENS: - raise Exception("%s is not a permitted generator", name) + raise InappropriateGeneratorException( + f"{name} is not a permitted generator" + ) def alternatives( self, diff --git a/datafaker/create.py b/datafaker/create.py index ce2a74e0..a877320d 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,6 +1,7 @@ """Functions and classes to create and populate the target database.""" import pathlib from collections import Counter +from types import ModuleType from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect @@ -97,8 
+98,7 @@ def create_db_vocab( def create_db_data( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, ) -> RowCounts: """Connect to a database and populate it with data.""" @@ -108,8 +108,7 @@ def create_db_data( return create_db_data_into( sorted_tables, - table_generator_dict, - story_generator_list, + df_module, num_passes, dst_dsn, settings.dst_schema, @@ -118,8 +117,7 @@ def create_db_data( def create_db_data_into( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, db_dsn: str, schema_name: str | None, @@ -145,12 +143,13 @@ def create_db_data_into( row_counts += populate( dst_conn, sorted_tables, - table_generator_dict, - story_generator_list, + df_module.table_generator_dict, + df_module.story_generator_list, ) return row_counts +# pylint: disable=too-many-instance-attributes class StoryIterator: """Iterates through all the rows produced by all the stories.""" @@ -305,7 +304,9 @@ def populate( story_iterator.insert() t = story_iterator.table_name() if t is None: - raise Exception("Internal error") + raise AssertionError( + "Internal error: story iterator returns None but not is_ended" + ) row_counts[t] = row_counts.get(t, 0) + 1 story_iterator.next() diff --git a/datafaker/generators.py b/datafaker/generators.py deleted file mode 100644 index 0b760e45..00000000 --- a/datafaker/generators.py +++ /dev/null @@ -1,2160 +0,0 @@ -"""Generator factories for making generators for single columns.""" - -import decimal -import math -import re -import typing -from abc import ABC, abstractmethod -from collections.abc import Mapping -from dataclasses import dataclass -from functools import lru_cache -from itertools import chain, combinations -from typing import Any, Callable, Iterable, Sequence, Union - -import mimesis 
-import mimesis.locales -import sqlalchemy -from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text -from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine -from typing_extensions import Self - -from datafaker.base import DistributionGenerator -from datafaker.utils import T, logger - -NumericType = Union[int, float] - -# How many distinct values can we have before we consider a -# choice distribution to be infeasible? -MAXIMUM_CHOICES = 500 - -dist_gen = DistributionGenerator() -generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) - - -class Generator(ABC): - """ - Random data generator. - - A generator is specific to a particular column in a particular table in - a particluar database. - - A generator knows how to fetch its summary data from the database, how to calculate - its fit (if apropriate) and which function actually does the generation. - - It also knows these summary statistics for the column it was instantiated on, - and therefore knows how to generate fake data for that column. - """ - - @abstractmethod - def function_name(self) -> str: - """Get the name of the generator function to put into df.py.""" - - def name(self) -> str: - """ - Get the name of the generator. - - Usually the same as the function name, but can be different to distinguish - between generators that have the same function but different queries. - """ - return self.function_name() - - @abstractmethod - def nominal_kwargs(self) -> dict[str, str]: - """ - Get the kwargs the generator wants to be called with. - - The values will tend to be references to something in the src-stats.yaml - file. - For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will - provide the value stored in src-stats.yaml as - SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument - to the generator function. 
- """ - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """ - Get the SQL clauses to add to a SELECT ... FROM {table} query. - - Will add to SRC_STATS["auto__{table}"] - For example { - "count": { - "clause": "COUNT(*)", - "comment": "number of rows in table {table}" - }, "avg_thiscolumn": { - "clause": "AVG(thiscolumn)", - "comment": "Average value of thiscolumn in table {table}" - }} - will make the clause become: - "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable" - and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and - SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file. - """ - return {} - - def custom_queries(self) -> dict[str, dict[str, str]]: - """ - Get the SQL queries to add to SRC_STATS. - - Should be used for queries that do not follow the SELECT ... FROM table format - using aggregate queries, because these should use select_aggregate_clauses. - - For example {"myquery": { - "query": "SELECT one, too AS two FROM mytable WHERE too > 1", - "comment": "big enough one and two from table mytable" - }} - will populate SRC_STATS["myquery"]["results"][0]["one"] - and SRC_STATS["myquery"]["results"][0]["two"] - in the src-stats.yaml file. - - Keys should be chosen to minimize the chances of clashing with other queries, - for example "auto__{table}__{column}__{queryname}" - """ - return {} - - @abstractmethod - def actual_kwargs(self) -> dict[str, Any]: - """ - Get the kwargs (summary statistics) this generator is instantiated with. - - This must match `nominal_kwargs` in structure. - """ - - @abstractmethod - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - - def fit(self, default: float = -1) -> float: - """ - Return a value representing how well the distribution fits the real source data. - - 0.0 means "perfectly". - Returns default if no fitness has been defined. 
- """ - return default - - -class PredefinedGenerator(Generator): - """Generator built from an existing config.yaml.""" - - SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") - AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") - SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') - - def _get_src_stats_mentioned(self, val: Any) -> set[str]: - if not val: - return set() - if isinstance(val, str): - ss = self.SRC_STAT_NAME_RE.match(val) - if ss: - ss_name = ss.group(1) - logger.debug("Found SRC_STATS reference %s", ss_name) - return set([ss_name]) - logger.debug("Value %s does not seem to be a SRC_STATS reference", val) - return set() - if isinstance(val, list): - return set.union(*(self._get_src_stats_mentioned(v) for v in val)) - if isinstance(val, dict): - return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) - return set() - - def __init__( - self, - table_name: str, - generator_object: Mapping[str, Any], - config: Mapping[str, Any], - ): - """ - Initialise a generator from a config.yaml. - - :param config: The entire configuration. - :param generator_object: The part of the configuration at tables.*.row_generators - """ - logger.debug( - "Creating a PredefinedGenerator %s from table %s", - generator_object["name"], - table_name, - ) - self._table_name = table_name - self._name: str = generator_object["name"] - self._kwn: dict[str, str] = generator_object.get("kwargs", {}) - self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) - # Need to deal with this somehow (or remove it from the schema) - self._argn: list[str] = generator_object.get("args", []) - self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} - self._custom_queries = {} - for sstat in config.get("src-stats", []): - name: str = sstat["name"] - dpq = sstat.get("dp-query", None) - query = sstat.get( - "query", dpq - ) # ... should we really be combining query and dp-query? 
- comments = sstat.get("comments", []) - if name in self._src_stats_mentioned: - logger.debug("Found a src-stats entry for %s", name) - # This query is one that this generator is interested in - sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) - # sam.group(2) is the table name from the FROM clause of the query - if sam and name == f"auto__{sam.group(2)}": - # name is auto__{table_name}, so it's a select_aggregate, so we split up its clauses - sacs = [ - self.AS_CLAUSE_RE.match(clause) - for clause in sam.group(1).split(",") - ] - # Work out what select_aggregate_clauses this represents - for sac in sacs: - if sac is not None: - comment = comments.pop() if comments else None - self._select_aggregate_clauses[sac.group(2)] = { - "clause": sac.group(1), - "comment": comment, - } - else: - # some other name, so must be a custom query - logger.debug("Custom query %s is '%s'", name, query) - self._custom_queries[name] = { - "query": query, - "comment": comments[0] if comments else None, - } - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def nominal_kwargs(self) -> dict[str, str]: - """Get the arguments to be entered into ``config.yaml``.""" - return self._kwn - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - return self._select_aggregate_clauses - - def custom_queries(self) -> dict[str, dict[str, str]]: - """Get the queries the generators need to call.""" - return self._custom_queries - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - # Run the queries from nominal_kwargs - # ... - logger.error("PredefinedGenerator.actual_kwargs not implemented yet") - return {} - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - # Call the function if we can. 
This could be tricky... - # ... - logger.error("PredefinedGenerator.generate_data not implemented yet") - return [] - - -class GeneratorFactory(ABC): - """A factory for making generators appropriate for a database column.""" - - @abstractmethod - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - - -class Buckets: - """ - Measured buckets for a real distribution. - - Finds the real distribution of continuous data so that we can measure - the fit of generators against it. - """ - - def __init__( - self, - engine: Engine, - table_name: str, - column_name: str, - mean: float, - stddev: float, - count: int, - ): - """Initialise a Buckets object.""" - with engine.connect() as connection: - raw_buckets = connection.execute( - text( - f"SELECT COUNT({column_name}) AS f," - f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b" - f" FROM {table_name} GROUP BY b" - ) - ) - self.buckets: Sequence[int] = [0] * 10 - for rb in raw_buckets: - if rb.b is not None: - bucket = min(9, max(0, int(rb.b) + 1)) - self.buckets[bucket] += rb.f / count - self.mean = mean - self.stddev = stddev - - @classmethod - def make_buckets( - cls, engine: Engine, table_name: str, column_name: str - ) -> Self | None: - """ - Construct a Buckets object. - - Calculates the mean and standard deviation of the values in the column - specified and makes ten buckets, centered on the mean and each half - a standard deviation wide (except for the end two that extend to - infinity). Each bucket will be set to the count of the number of values - in the column within that bucket. 
- """ - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT AVG({column_name}) AS mean," - f" STDDEV({column_name}) AS stddev," - f" COUNT({column_name}) AS count FROM {table_name}" - ) - ).first() - if result is None or result.stddev is None or getattr(result, "count") < 2: - return None - try: - buckets = cls( - engine, - table_name, - column_name, - result.mean, - result.stddev, - getattr(result, "count"), - ) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("Failed to instantiate Buckets object: %s", exc) - return None - return buckets - - def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: - """Figure out the fit from bucket counts from the generator distribution.""" - return fit_from_buckets(self.buckets, bucket_counts) - - def fit_from_values(self, values: list[float]) -> float: - """Figure out the fit from samples from the generator distribution.""" - buckets = [0] * 10 - x = self.mean - 2 * self.stddev - w = self.stddev / 2 - for v in values: - b = min(9, max(0, int((v - x) / w))) - buckets[b] += 1 - return self.fit_from_counts(buckets) - - -class MultiGeneratorFactory(GeneratorFactory): - """A composite factory.""" - - def __init__(self, factories: list[GeneratorFactory]): - """Initialise a MultiGeneratorFactory.""" - super().__init__() - self.factories = factories - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - return [ - generator - for factory in self.factories - for generator in factory.get_generators(columns, engine) - ] - - -class MimesisGeneratorBase(Generator): - """Base class for a generator using Mimesis.""" - - def __init__( - self, - function_name: str, - ): - """ - Initialise a generator that uses Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. 
- """ - super().__init__() - f = generic - for part in function_name.split("."): - if not hasattr(f, part): - raise Exception( - f"Mimesis does not have a function {function_name}: {part} not found" - ) - f = getattr(f, part) - if not callable(f): - raise Exception( - f"Mimesis object {function_name} is not a callable," - " so cannot be used as a generator" - ) - self._name = "generic." + function_name - self._generator_function = f - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [self._generator_function() for _ in range(count)] - - -class MimesisGenerator(MimesisGeneratorBase): - """A generator using Mimesis.""" - - def __init__( - self, - function_name: str, - value_fn: Callable[[Any], float] | None = None, - buckets: Buckets | None = None, - ): - """ - Initialise a generator using Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. - :param value_fn: Function to convert generator output to floats, if needed. The values - thus produced are compared against the buckets to estimate the fit. - :param buckets: The distribution of string lengths in the real data. If this is None - then the fit method will return None. 
- """ - super().__init__(function_name) - if buckets is None: - self._fit = None - return - samples = self.generate_data(400) - if value_fn: - samples = [value_fn(s) for s in samples] - self._fit = buckets.fit_from_values(samples) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return {} - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {} - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default if self._fit is None else self._fit - - -class MimesisGeneratorTruncated(MimesisGenerator): - """A string generator using Mimesis that must fit within a certain number of characters.""" - - def __init__( - self, - function_name: str, - length: int, - value_fn: Callable[[Any], float] | None = None, - buckets: Buckets | None = None, - ): - """Initialise a MimesisGeneratorTruncated.""" - self._length = length - super().__init__(function_name, value_fn, buckets) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.truncated_string" - - def name(self) -> str: - """Get the name of the generator.""" - return f"{self._name} [truncated to {self._length}]" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return 
[self._generator_function()[: self._length] for _ in range(count)] - - -class MimesisDateTimeGenerator(MimesisGeneratorBase): - """DateTime generator using Mimesis.""" - - def __init__( - self, - column: Column, - function_name: str, - min_year: str, - max_year: str, - start: int, - end: int, - ) -> None: - """ - Initialise a MimesisDateTimeGenerator. - - :param column: The column to generate into - :param function_name: The name of the mimesis function - :param min_year: SQL expression extracting the minimum year - :param min_year: SQL expression extracting the maximum year - :param start: The actual first year found - :param end: The actual last year found - """ - super().__init__(function_name) - self._column = column - self._max_year = max_year - self._min_year = min_year - self._start = start - self._end = end - - @classmethod - def make_singleton( - cls, column: Column, engine: Engine, function_name: str - ) -> Sequence[Generator]: - """Make the appropriate generation configuration for this column.""" - extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" - max_year = f"MAX({extract_year})" - min_year = f"MIN({extract_year})" - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}" - ) - ).first() - if result is None or result.start is None or result.end is None: - return [] - return [ - MimesisDateTimeGenerator( - column, - function_name, - min_year, - max_year, - int(result.start), - int(result.end), - ) - ] - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "start": ( - f'SRC_STATS["auto__{self._column.table.name}"]["results"]' - f'[0]["{self._column.name}__start"]' - ), - "end": ( - f'SRC_STATS["auto__{self._column.table.name}"]["results"]' - f'[0]["{self._column.name}__end"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this 
generator was instantiated with.""" - return { - "start": self._start, - "end": self._end, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - return { - f"{self._column.name}__start": { - "clause": self._min_year, - "comment": ( - f"Earliest year found for column {self._column.name}" - f" in table {self._column.table.name}" - ), - }, - f"{self._column.name}__end": { - "clause": self._max_year, - "comment": ( - f"Latest year found for column {self._column.name}" - f" in table {self._column.table.name}" - ), - }, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - self._generator_function(start=self._start, end=self._end) - for _ in range(count) - ] - - -def get_column_type(column: Column) -> TypeEngine: - """Get the type of the column, generic if possible.""" - try: - return column.type.as_generic() - except NotImplementedError: - return column.type - - -class MimesisStringGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return strings.""" - - GENERATOR_NAMES = [ - "address.calling_code", - "address.city", - "address.continent", - "address.country", - "address.country_code", - "address.postal_code", - "address.province", - "address.street_number", - "address.street_name", - "address.street_suffix", - "person.blood_type", - "person.email", - "person.first_name", - "person.last_name", - "person.full_name", - "person.gender", - "person.language", - "person.nationality", - "person.occupation", - "person.password", - "person.title", - "person.university", - "person.username", - "person.worldview", - "text.answer", - "text.color", - "text.level", - "text.quote", - "text.sentence", - "text.text", - "text.word", - ] - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 
1: - return [] - column = columns[0] - column_type = get_column_type(column) - if not isinstance(column_type, String): - return [] - try: - buckets = Buckets.make_buckets( - engine, - column.table.name, - f"LENGTH({column.name})", - ) - fitness_fn = len - except Exception: - # Some column types that appear to be strings (such as enums) - # cannot have their lengths measured. In this case we cannot - # detect fitness using lengths. - buckets = None - fitness_fn = None - length = column_type.length - if length: - return list( - map( - lambda gen: MimesisGeneratorTruncated( - gen, length, fitness_fn, buckets - ), - self.GENERATOR_NAMES, - ) - ) - return list( - map( - lambda gen: MimesisGenerator(gen, fitness_fn, buckets), - self.GENERATOR_NAMES, - ) - ) - - -class MimesisFloatGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return floating point numbers.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - if not isinstance(get_column_type(column), Numeric): - return [] - return list( - map( - MimesisGenerator, - [ - "person.height", - ], - ) - ) - - -class MimesisDateGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return dates.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Date): - return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") - - -class MimesisDateTimeGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return datetimes.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if 
len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, DateTime): - return [] - return MimesisDateTimeGenerator.make_singleton( - column, engine, "datetime.datetime" - ) - - -class MimesisTimeGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return times.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Time): - return [] - return [MimesisGenerator("datetime.time")] - - -class MimesisIntegerGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return integers.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - return [MimesisGenerator("person.weight")] - - -def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: - """Calculate the fit by comparing a pair of lists of buckets.""" - sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) - count = len(ys) - return sum_diff_squared / (count * count) - - -class ContinuousDistributionGenerator(Generator): - """Base class for generators producing continuous distributions.""" - - expected_buckets: Sequence[NumericType] = [] - - def __init__(self, table_name: str, column_name: str, buckets: Buckets): - """Initialise a ContinuousDistributionGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "mean": ( - 
f'SRC_STATS["auto__{self.table_name}"]["results"]' - f'[0]["mean__{self.column_name}"]' - ), - "sd": ( - f'SRC_STATS["auto__{self.table_name}"]["results"]' - f'[0]["stddev__{self.column_name}"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - if self.buckets is None: - return {} - return { - "mean": self.buckets.mean, - "sd": self.buckets.stddev, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"mean__{self.column_name}": { - "clause": f"AVG({self.column_name})", - "comment": f"Mean of {self.column_name} from table {self.table_name}", - }, - f"stddev__{self.column_name}": { - "clause": f"STDDEV({self.column_name})", - "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class GaussianGenerator(ContinuousDistributionGenerator): - """Generator producing numbers in a Gaussian (normal) distribution.""" - - expected_buckets = [ - 0.0227, - 0.0441, - 0.0918, - 0.1499, - 0.1915, - 0.1915, - 0.1499, - 0.0918, - 0.0441, - 0.0227, - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.normal" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.normal(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class UniformGenerator(ContinuousDistributionGenerator): - """Generator producing numbers in a uniform distribution.""" - - expected_buckets = [ - 0, - 0.06698, - 0.14434, - 0.14434, - 0.14434, - 0.14434, - 0.14434, 
- 0.14434, - 0.06698, - 0, - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.uniform_ms" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class ContinuousDistributionGeneratorFactory(GeneratorFactory): - """All generators that want an average and standard deviation.""" - - def _get_generators_from_buckets( - self, - _engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> Sequence[Generator]: - return [ - GaussianGenerator(table_name, column_name, buckets), - UniformGenerator(table_name, column_name, buckets), - ] - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_name = column.name - table_name = column.table.name - buckets = Buckets.make_buckets(engine, table_name, column_name) - if buckets is None: - return [] - return self._get_generators_from_buckets( - engine, table_name, column_name, buckets - ) - - -class LogNormalGenerator(Generator): - """Generator producing numbers in a log-normal distribution.""" - - # TODO: figure out the real buckets here (this was from a random sample in R) - expected_buckets = [ - 0, - 0, - 0, - 0.28627, - 0.40607, - 0.14937, - 0.06735, - 0.03492, - 0.01918, - 0.03684, - ] - - def __init__( - self, - table_name: str, - column_name: str, - buckets: Buckets, - logmean: float, - logstddev: float, - ): - """Initialise a LogNormalGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - self.logmean = logmean - self.logstddev = 
logstddev - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.lognormal" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "logmean": ( - f'SRC_STATS["auto__{self.table_name}"]["results"][0]' - f'["logmean__{self.column_name}"]' - ), - "logsd": ( - f'SRC_STATS["auto__{self.table_name}"]["results"][0]' - f'["logstddev__{self.column_name}"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "logmean": self.logmean, - "logsd": self.logstddev, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"logmean__{self.column_name}": { - "clause": ( - f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" - " ELSE NULL END)" - ), - "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", - }, - f"logstddev__{self.column_name}": { - "clause": ( - f"STDDEV(CASE WHEN 0<{self.column_name}" - f" THEN LN({self.column_name}) ELSE NULL END)" - ), - "comment": ( - f"Standard deviation of logs of {self.column_name}" - f" from table {self.table_name}" - ), - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): - """All generators that want an average and standard deviation of log data.""" - - def _get_generators_from_buckets( 
- self, - engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> Sequence[Generator]: - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" - " ELSE NULL END) AS logmean," - f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" - f" AS logstddev FROM {table_name}" - ) - ).first() - if result is None or result.logstddev is None: - return [] - return [ - LogNormalGenerator( - table_name, - column_name, - buckets, - float(result.logmean), - float(result.logstddev), - ) - ] - - -def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: - """ - Get a zipf distribution for a certain number of items. - - :param total: The total number of items to be distributed. - :param bins: The total number of bins to distribute the items into. - :return: A generator of the number of items in each bin, from the - largest to the smallest. - """ - basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) - bd_remaining = sum(basic_dist) - for b in basic_dist: - # yield b/bd_remaining of the `total` remaining - if bd_remaining == 0: - yield 0 - else: - x = math.floor(0.5 + total * b / bd_remaining) - bd_remaining -= x * bd_remaining / total - total -= x - yield x - - -class ChoiceGenerator(Generator): - """Base generator for all generators producing choices of items.""" - - STORE_COUNTS = False - - def __init__( - self, - table_name: str, - column_name: str, - values: list[Any], - counts: list[int], - sample_count: int | None = None, - suppress_count: int = 0, - ) -> None: - """Initialise a ChoiceGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.values = values - estimated_counts = self.get_estimated_counts(counts) - self._fit = fit_from_buckets(counts, estimated_counts) - - extra_results = "" - extra_expo = "" - extra_comment = "" - if self.STORE_COUNTS: - extra_results = 
f", COUNT({column_name}) AS count" - extra_expo = ", count" - extra_comment = " and their counts" - if suppress_count == 0: - if sample_count is None: - self._query = ( - f"SELECT {column_name} AS value{extra_results} FROM {table_name}" - f" WHERE {column_name} IS NOT NULL GROUP BY value" - f" ORDER BY COUNT({column_name}) DESC" - ) - self._comment = ( - f"All the values{extra_comment} that appear in column {column_name}" - f" of table {table_name}" - ) - self._annotation = None - else: - self._query = ( - f"SELECT {column_name} AS value{extra_results} FROM" - f" (SELECT {column_name} FROM {table_name}" - f" WHERE {column_name} IS NOT NULL" - f" ORDER BY RANDOM() LIMIT {sample_count})" - f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" - ) - self._comment = ( - f"The values{extra_comment} that appear in column {column_name}" - f" of a random sample of {sample_count} rows of table {table_name}" - ) - self._annotation = "sampled" - else: - if sample_count is None: - self._query = ( - f"SELECT value{extra_expo} FROM" - f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" - f" FROM {table_name} WHERE {column_name} IS NOT NULL" - f" GROUP BY value ORDER BY count DESC) AS _inner" - f" WHERE {suppress_count} < count" - ) - self._comment = ( - f"All the values{extra_comment} that appear in column {column_name}" - f" of table {table_name} more than {suppress_count} times" - ) - self._annotation = "suppressed" - else: - self._query = ( - f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" - f" (SELECT {column_name} AS value FROM {table_name}" - f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" - f" AS _inner GROUP BY value ORDER BY count DESC)" - f" AS _inner WHERE {suppress_count} < count" - ) - self._comment = ( - f"The values{extra_comment} that appear more than {suppress_count} times" - f" in column {column_name}, out of a random sample of {sample_count} rows" - f" of table {table_name}" - ) - 
self._annotation = "sampled and suppressed" - - @abstractmethod - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', - } - - def name(self) -> str: - """Get the name of the generator.""" - n = super().name() - if self._annotation is None: - return n - return f"{n} [{self._annotation}]" - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "a": self.values, - } - - def custom_queries(self) -> dict[str, dict[str, str]]: - """Get the queries the generators need to call.""" - qs = super().custom_queries() - return { - **qs, - f"auto__{self.table_name}__{self.column_name}": { - "query": self._query, - "comment": self._comment, - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default if self._fit is None else self._fit - - -class ZipfChoiceGenerator(ChoiceGenerator): - """Generator producing items in a Zipf distribution.""" - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return list(zipf_distribution(sum(counts), len(counts))) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.zipf_choice" - - def generate_data(self, count: int) -> list[float]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) - ] - - -def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: - """ - Construct a distribution putting ``total`` items uniformly 
into ``bins`` bins. - - If they don't fit exactly evenly, the earlier bins will have one more - item than the later bins so the total is as required. - """ - p = total // bins - n = total % bins - for _ in range(0, n): - yield p + 1 - for _ in range(n, bins): - yield p - - -class UniformChoiceGenerator(ChoiceGenerator): - """A generator producing values, each roughly as frequently as each other.""" - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return list(uniform_distribution(sum(counts), len(counts))) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.choice" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.choice(self.values) for _ in range(count)] - - -class WeightedChoiceGenerator(ChoiceGenerator): - """Choice generator that matches the source data's frequency.""" - - STORE_COUNTS = True - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return counts - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.weighted_choice" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.weighted_choice(self.values) for _ in range(count)] - - -class ValueGatherer: - """ - Gathers values from a query of values and counts. - - The query must return columns ``v`` for a value and ``f`` for the - count of how many of those values there are. 
- These values will be gathered into a number of properties: - ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts - in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and - ``count`` giving these values and counts. ``counts_not_suppressed``, - ``values_not_suppressed`` and ``cvs_not_suppressed`` are the - equivalents with the counts less than or equal to ``suppress_count`` - removed. - - :param suppress_count: value with a count of this or fewer will be excluded - from the suppressed values. - """ - - def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: - """Initialise a ValueGatherer.""" - values = [] # All values found - counts = [] # The number or each value - cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times - counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if isinstance(v, decimal.Decimal): - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if suppress_count < c: - counts_not_suppressed.append(c) - v = result.v - if isinstance(v, decimal.Decimal): - v = float(v) - values_not_suppressed.append(v) - cvs_not_suppressed.append({"value": v, "count": c}) - self.values = values - self.counts = counts - self.cvs = cvs - self.values_not_suppressed = values_not_suppressed - self.counts_not_suppressed = counts_not_suppressed - self.cvs_not_suppressed = cvs_not_suppressed - - -class ChoiceGeneratorFactory(GeneratorFactory): - """All generators that want an average and standard deviation.""" - - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 5 - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to 
these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - column_name = column.name - table_name = column.table.name - generators = [] - with engine.connect() as connection: - results = connection.execute( - text( - f"SELECT {column_name} AS v, COUNT({column_name})" - f" AS f FROM {table_name} GROUP BY v" - f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" - ) - ) - if results is not None and results.rowcount <= MAXIMUM_CHOICES: - vg = ValueGatherer(results, self.SUPPRESS_COUNT) - if vg.counts: - generators += [ - ZipfChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - UniformChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - WeightedChoiceGenerator( - table_name, column_name, vg.cvs, vg.counts - ), - ] - results = connection.execute( - text( - f"SELECT v, COUNT(v) AS f FROM" - f" (SELECT {column_name} as v FROM {table_name}" - f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" - f" AS _inner GROUP BY v ORDER BY f DESC" - ) - ) - if results is not None: - vg = ValueGatherer(results, self.SUPPRESS_COUNT) - if vg.counts: - generators += [ - ZipfChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - UniformChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - WeightedChoiceGenerator( - table_name, column_name, vg.cvs, vg.counts - ), - ] - generators += [ - ZipfChoiceGenerator( - table_name, - column_name, - vg.values, - vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - UniformChoiceGenerator( - table_name, - column_name, - vg.values, - vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - WeightedChoiceGenerator( - table_name, - column_name, - vg.cvs, - vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - ] - if vg.counts_not_suppressed: - generators += [ - ZipfChoiceGenerator( - table_name, - column_name, - vg.values_not_suppressed, - vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - UniformChoiceGenerator( - table_name, - 
column_name, - vg.values_not_suppressed, - vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - WeightedChoiceGenerator( - table_name=table_name, - column_name=column_name, - values=vg.cvs_not_suppressed, - counts=vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - ] - return generators - - -class ConstantGenerator(Generator): - """Generator that always produces the same value.""" - - def __init__(self, value: Any) -> None: - """Initialise the ConstantGenerator.""" - super().__init__() - self.value = value - self.repr = repr(value) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.constant" - - def nominal_kwargs(self) -> dict[str, str]: - """Get the arguments to be entered into ``config.yaml``.""" - return {"value": self.repr} - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {"value": self.value} - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [self.value for _ in range(count)] - - -class ConstantGeneratorFactory(GeneratorFactory): - """Just the null generator.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate for these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - if column.nullable: - return [ConstantGenerator(None)] - c_type = get_column_type(column) - if isinstance(c_type, String): - return [ConstantGenerator("")] - if isinstance(c_type, Numeric): - return [ConstantGenerator(0.0)] - if isinstance(c_type, Integer): - return [ConstantGenerator(0)] - return [] - - -class MultivariateNormalGenerator(Generator): - """Generator of multiple values drawn from a multivariate normal distribution.""" - - def __init__( - self, - 
table_name: str, - column_names: list[str], - query: str, - covariates: RowMapping, - function_name: str, - ) -> None: - """Initialise a MultivariateNormalGenerator.""" - self._table = table_name - self._columns = column_names - self._query = query - self._covariates = covariates - self._function_name = function_name - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen." + self._function_name - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', - } - - def custom_queries(self) -> dict[str, Any]: - """Get the queries the generators need to call.""" - cols = ", ".join(self._columns) - return { - f"auto__cov__{self._table}": { - "comment": ( - f"Means and covariate matrix for the columns {cols}," - " so that we can produce the relatedness between these in the fake data." - ), - "query": self._query, - } - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {"cov": self._covariates} - - def generate_data(self, count: int) -> list[Any]: - """Generate 'count' random data points for this column.""" - return [ - getattr(dist_gen, self._function_name)(self._covariates) - for _ in range(count) - ] - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default - - -class MultivariateNormalGeneratorFactory(GeneratorFactory): - """Normal distribution generator factory.""" - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "multivariate_normal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression for whether this column should be queried.""" - return column.name + " IS NOT NULL" - - def query_var(self, column: str) -> str: - """Get the SQL expression of the value to 
query for this column.""" - return column - - def query( - self, - table: str, - columns: list[Column], - predicates: list[str] = [], - group_by_clause: str = "", - constant_clauses: str = "", - constants: str = "", - suppress_count: int = 1, - sample_count: int | None = None, - ) -> str: - """ - Get a query for the basics for multivariate normal/lognormal parameters. - - :param table: The name of the table to be queried. - :param columns: The columns in the multivariate distribution. - :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. - :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not ""). - :param constant_clauses: Extra output columns in the outer SELECT clause, such - as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma. - :param constants: Extra output columns in the inner SELECT clause. Used to - deliver columns to the outer select, such as ", column_one, column_two". - Note the initial comma. - :param suppress_count: a group smaller than this will be suppressed. - :param sample_count: this many samples will be taken from each partition. 
- """ - preds = [self.query_predicate(col) for col in columns] + predicates - where = " WHERE " + " AND ".join(preds) if preds else "" - avgs = "".join( - f", AVG({self.query_var(col.name)}) AS m{i}" - for i, col in enumerate(columns) - ) - multiples = "".join( - f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" - for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[: iy + 1]) - ) - means = "".join(f", _q.m{i}" for i in range(len(columns))) - covs = "".join( - ( - f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" - f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" - ) - for iy in range(len(columns)) - for ix in range(iy + 1) - ) - if sample_count is None: - subquery = table + where - else: - subquery = ( - f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" - f" LIMIT {sample_count}) AS _sampled" - ) - # if there are any numeric columns we need at least two rows to make any (co)variances at all - suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" - return ( - f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" - f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" - f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" - ) - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators for these columns.""" - # For the case of one column we'll use GaussianGenerator - if len(columns) < 2: - return [] - # All columns must be numeric - for c in columns: - ct = get_column_type(c) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_names = [c.name for c in columns] - table = columns[0].table.name - query = self.query(table, columns) - with engine.connect() as connection: - try: - covariates = connection.execute(text(query)).mappings().first() - except Exception as e: - logger.debug("SQL query %s failed with error %s", query, e) - return [] - if not covariates or 
covariates["c0_0"] is None: - return [] - return [ - MultivariateNormalGenerator( - table, - column_names, - query, - covariates, - self.function_name(), - ) - ] - - -class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - """Multivariate lognormal generator factory.""" - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression for whether this column should be queried.""" - return f"COALESCE(0 < {column.name}, FALSE)" - - def query_var(self, column: str) -> str: - """Get the expression to query for, for this column.""" - return f"LN({column})" - - -def text_list(items: Iterable[str]) -> str: - """Concatenate the items with commas and one "and".""" - item_i = iter(items) - try: - last_item = next(item_i) - except StopIteration: - return "" - try: - so_far = next(item_i) - except StopIteration: - return last_item - for item in item_i: - so_far += ", " + last_item - last_item = item - return so_far + " and " + last_item - - -@dataclass -class RowPartition: - """A partition where all the rows have the same pattern of NULLs.""" - - query: str - # list of numeric columns - included_numeric: list[Column] - # map of indices to column names that are being grouped by. - # The indices are indices of where they need to be inserted into - # the generator outputs. 
- included_choice: dict[int, str] - # map of column names to clause that defines the partition - # such as "mycolumn IS NULL" - excluded_columns: dict[str, str] - # map of constant outputs that need to be inserted into the - # list of included column values (so once the generator has - # been run and the included_choice values have been - # added): {index: value} - constant_outputs: dict[int, Any] - # The actual covariates from the source database - covariates: Sequence[RowMapping] - - def comment(self) -> str: - """Make an appropriate comment for this partition.""" - caveat = "" - if self.included_choice: - caveat = f" (for each possible value of {text_list(self.included_choice.values())})" - if not self.included_numeric: - return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}" - if not self.excluded_columns: - where = "" - else: - where = f" where {text_list(self.excluded_columns.values())}" - if len(self.included_numeric) == 1: - return ( - f"Mean and variance for column {self.included_numeric[0].name}{where}." - ) - return ( - "Means and covariate matrix for the columns " - f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" - " produce the relatedness between these in the fake data." - ) - - -class NullPartitionedNormalGenerator(Generator): - """ - A generator of mixed numeric and non-numeric data. - - Generates data that matches the source data in - missingness, choice of non-numeric data and numeric - data. - - For the numeric data to be generated, samples of rows for each - combination of non-numeric values and missingness. If any such - combination has only one line in the source data (or sample of - the source data if sampling), it will not be generated as a - covariate matrix cannot be generated from one source row - (although if the data is all non-numeric values and nulls, single - rows are used because no covariate matrix is required for this). 
- """ - - def __init__( - self, - query_name: str, - partitions: dict[int, RowPartition], - function_name: str = "grouped_multivariate_lognormal", - name_suffix: str | None = None, - partition_count_query: str | None = None, - partition_counts: Iterable[RowMapping] = [], - partition_count_comment: str | None = None, - ): - """Initialise a NullPartitionedNormalGenerator.""" - self._query_name = query_name - self._partitions = partitions - self._function_name = function_name - self._partition_count_query = partition_count_query - self._partition_counts = [dict(pc) for pc in partition_counts] - self._partition_count_comment = partition_count_comment - if name_suffix: - self._name = f"null-partitioned {function_name} [{name_suffix}]" - else: - self._name = f"null-partitioned {function_name}" - - def name(self) -> str: - """Get the name of the generator.""" - return self._name - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.alternatives" - - def _nominal_kwargs_with_combinations( - self, index: int, partition: RowPartition - ) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml`` for a single partition.""" - count = ( - 'sum(r["count"] for r in' - f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' - ) - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": '"constant"', - "params": {"value": [None] * len(partition.constant_outputs)}, - } - covariates = { - "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' - } - if not partition.constant_outputs: - return { - "count": count, - "name": f'"{self._function_name}"', - "params": covariates, - } - return { - "count": count, - "name": '"with_constants_at"', - "params": { - "constants_at": partition.constant_outputs, - "subgen": f'"{self._function_name}"', - "params": covariates, - }, - } - - def _count_query_name(self) -> str: - return 
f"auto__cov__{self._query_name}__counts" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "alternative_configs": [ - self._nominal_kwargs_with_combinations(index, self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', - } - - def custom_queries(self) -> dict[str, Any]: - """Get the queries the generators need to call.""" - partitions = { - f"auto__cov__{self._query_name}__alt_{index}": { - "comment": partition.comment(), - "query": partition.query, - } - for index, partition in self._partitions.items() - } - if not self._partition_count_query: - return partitions - return { - self._count_query_name(): { - "comment": self._partition_count_comment, - "query": self._partition_count_query, - }, - **partitions, - } - - def _actual_kwargs_with_combinations( - self, partition: RowPartition - ) -> dict[str, Any]: - count = sum(row["count"] for row in partition.covariates) - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": "constant", - "params": {"value": [None] * len(partition.excluded_columns)}, - } - covariates = { - "covs": partition.covariates, - } - if not partition.constant_outputs: - return { - "count": count, - "name": self._function_name, - "params": covariates, - } - return { - "count": count, - "name": "with_constants_at", - "params": { - "constants_at": partition.constant_outputs, - "subgen": self._function_name, - "params": covariates, - }, - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "alternative_configs": [ - self._actual_kwargs_with_combinations(self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": self._partition_counts, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate 'count' random 
data points for this column.""" - kwargs = self.actual_kwargs() - return [dist_gen.alternatives(**kwargs) for _ in range(count)] - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default - - -def is_numeric(col: Column) -> bool: - """Test if this column stores a numeric value.""" - ct = get_column_type(col) - return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys - - -def powerset(xs: list[T]) -> Iterable[Iterable[T]]: - """Get a list of all sublists of ``input``.""" - return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) - - -@dataclass -class NullableColumn: - """A reference to a nullable column whose nullability is part of a partitioning.""" - - column: Column - # The bit (power of two) of the number of the partition in the partition sizes list - bitmask: int - - -class NullPatternPartition: - """Get the definition of a partition (in other words, what makes it not another partition).""" - - def __init__( - self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] - ): - """Initialise a pattern of nulls which can be queried for.""" - self.index = sum(nc.bitmask for nc in partition_nonnulls) - nonnull_columns = {nc.column.name for nc in partition_nonnulls} - self.included_numeric: list[Column] = [] - self.included_choice: dict[int, str] = {} - self.group_by_clause = "" - self.constant_clauses = "" - self.constants = "" - self.excluded: dict[str, str] = {} - self.predicates: list[str] = [] - self.nones: dict[int, None] = {} - for col_index, column in enumerate(columns): - col_name = column.name - if col_name in nonnull_columns or not column.nullable: - if is_numeric(column): - self.included_numeric.append(column) - else: - index = len(self.included_numeric) + len(self.included_choice) - self.included_choice[index] = col_name - if self.group_by_clause: - self.group_by_clause += ", " + col_name - else: - self.group_by_clause = " GROUP BY " + col_name 
- self.constant_clauses += f", _q.{col_name} AS k{index}" - self.constants += ", " + col_name - else: - self.excluded[col_name] = f"{col_name} IS NULL" - self.predicates.append(f"{col_name} IS NULL") - self.nones[col_index] = None - - -class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - """Produces null partitioned generators, for complex interdependent data.""" - - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 5 - EMPTY_RESULT = [ - RowMapping( - parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), - processors=None, - key_to_index={"count": 0}, - data=(0,), - ) - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "grouped_multivariate_normal" - - def query_predicate(self, column: Column) -> str: - """Get a SQL expression that is true when ``column`` is available for analysis.""" - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - """Return the expression we are querying for in this column.""" - return column - - def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: - """Get a list of nullable columns together with bitmasks.""" - out: list[NullableColumn] = [] - for col in columns: - if col.nullable: - out.append( - NullableColumn( - column=col, - bitmask=2 ** len(out), - ) - ) - return out - - def get_partition_count_query( - self, ncs: list[NullableColumn], table: str, where: str | None = None - ) -> str: - """ - Get a SQL expression returning columns ``count`` and ``index``. - - Each row returned represents one of the null pattern partitions. - ``index`` is the bitmask of all those nullable columns that are not null for - this partition, and ``count`` is the total number of rows in this partition. 
- """ - index_exp = " + ".join( - f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" - for nc in ncs - ) - if where is None: - return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' - return ( - 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' - f' {index_exp} AS "index"' - f' FROM {table} GROUP BY "index") AS _q {where}' - ) - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get any appropriate generators for these columns.""" - if len(columns) < 2: - return [] - nullable_columns = self.get_nullable_columns(columns) - if not nullable_columns: - return [] - table = columns[0].table.name - query_name = f"{table}__{columns[0].name}" - # Partitions for minimal suppression and no sampling - row_partitions_maximal: dict[int, RowPartition] = {} - # Partitions for normal suppression and severe sampling - row_partitions_ss: dict[int, RowPartition] = {} - for partition_nonnulls in powerset(nullable_columns): - partition_def = NullPatternPartition(columns, partition_nonnulls) - query = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - ) - row_partitions_maximal[partition_def.index] = RowPartition( - query, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - query = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - suppress_count=self.SUPPRESS_COUNT, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_ss[partition_def.index] = RowPartition( - query, - partition_def.included_numeric, - 
partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - gens: list[Generator] = [] - try: - with engine.connect() as connection: - partition_query_max = self.get_partition_count_query( - nullable_columns, table - ) - partition_count_max_results = ( - connection.execute(text(partition_query_max)).mappings().fetchall() - ) - count_comment = ( - "Number of rows for each combination of the columns" - f" { {nc.column.name for nc in nullable_columns} }" - f" of the table {table} being null" - ) - if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) - ) - partition_query_ss = self.get_partition_count_query( - nullable_columns, - table, - where=f"WHERE {self.SUPPRESS_COUNT} < count", - ) - partition_count_ss_results = ( - connection.execute(text(partition_query_ss)).mappings().fetchall() - ) - if self._execute_partition_queries(connection, row_partitions_ss): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - ) - ) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) - return [] - return gens - - def _execute_partition_queries( - self, - connection: Connection, - partitions: dict[int, RowPartition], - ) -> bool: - """ - Execute the query in each partition, filling in the covariates. - - :return: True if all the partitions work, False if any of them fail. 
- """ - found_nonzero = False - for rp in partitions.values(): - covs = connection.execute(text(rp.query)).mappings().fetchall() - if not covs or covs.count == 0 or covs[0]["count"] is None: - rp.covariates = self.EMPTY_RESULT - else: - rp.covariates = covs - found_nonzero = True - return found_nonzero - - -class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): - """ - A generator for numeric and non-numeric columns. - - Any values could be null, the distributions of the nonnull numeric columns - depend on each other and the other non-numeric column values. - """ - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "grouped_multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression testing if the value in this column should be used.""" - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - """Get the variable or expression we are querying for this column.""" - return f"LN({column})" - - -@lru_cache(1) -def everything_factory() -> GeneratorFactory: - """Get a factory that encapsulates all the other factories.""" - return MultiGeneratorFactory( - [ - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(), - NullPartitionedLogNormalGeneratorFactory(), - ] - ) diff --git a/datafaker/generators/__init__.py 
b/datafaker/generators/__init__.py new file mode 100644 index 00000000..c08d1203 --- /dev/null +++ b/datafaker/generators/__init__.py @@ -0,0 +1,53 @@ +"""Generators write generator function definitions and queries into config.yaml.""" + +from functools import lru_cache + +from datafaker.generators.base import ( + ConstantGeneratorFactory, + GeneratorFactory, + MultiGeneratorFactory, +) +from datafaker.generators.choice import ( + ChoiceGeneratorFactory, +) +from datafaker.generators.continuous import ( + ContinuousDistributionGeneratorFactory, + ContinuousLogDistributionGeneratorFactory, + MultivariateNormalGeneratorFactory, + MultivariateLogNormalGeneratorFactory, +) +from datafaker.generators.mimesis import ( + MimesisStringGeneratorFactory, + MimesisIntegerGeneratorFactory, + MimesisFloatGeneratorFactory, + MimesisDateGeneratorFactory, + MimesisDateTimeGeneratorFactory, + MimesisTimeGeneratorFactory, +) +from datafaker.generators.partitioned import( + NullPartitionedNormalGeneratorFactory, + NullPartitionedLogNormalGeneratorFactory, +) + + +@lru_cache(1) +def everything_factory() -> GeneratorFactory: + """Get a factory that encapsulates all the other factories.""" + return MultiGeneratorFactory( + [ + MimesisStringGeneratorFactory(), + MimesisIntegerGeneratorFactory(), + MimesisFloatGeneratorFactory(), + MimesisDateGeneratorFactory(), + MimesisDateTimeGeneratorFactory(), + MimesisTimeGeneratorFactory(), + ContinuousDistributionGeneratorFactory(), + ContinuousLogDistributionGeneratorFactory(), + ChoiceGeneratorFactory(), + ConstantGeneratorFactory(), + MultivariateNormalGeneratorFactory(), + MultivariateLogNormalGeneratorFactory(), + NullPartitionedNormalGeneratorFactory(), + NullPartitionedLogNormalGeneratorFactory(), + ] + ) diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py new file mode 100644 index 00000000..1adcb9ca --- /dev/null +++ b/datafaker/generators/base.py @@ -0,0 +1,417 @@ +"""Basic Generators and factories.""" + +import re 
+from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any, Sequence, Union + +import mimesis +import mimesis.locales +import sqlalchemy +from sqlalchemy import Column, Engine, text +from sqlalchemy.types import Integer, Numeric, String, TypeEngine +from typing_extensions import Self + +from datafaker.base import DistributionGenerator +from datafaker.utils import T, logger + +NumericType = Union[int, float] + + +dist_gen = DistributionGenerator() +generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) + + +class Generator(ABC): + """ + Random data generator. + + A generator is specific to a particular column in a particular table in + a particular database. + + A generator knows how to fetch its summary data from the database, how to calculate + its fit (if appropriate) and which function actually does the generation. + + It also knows these summary statistics for the column it was instantiated on, + and therefore knows how to generate fake data for that column. + """ + + @abstractmethod + def function_name(self) -> str: + """Get the name of the generator function to put into df.py.""" + + def name(self) -> str: + """ + Get the name of the generator. + + Usually the same as the function name, but can be different to distinguish + between generators that have the same function but different queries. + """ + return self.function_name() + + @abstractmethod + def nominal_kwargs(self) -> dict[str, str]: + """ + Get the kwargs the generator wants to be called with. + + The values will tend to be references to something in the src-stats.yaml + file. + For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will + provide the value stored in src-stats.yaml as + SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument + to the generator function. + """ + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """ + Get the SQL clauses to add to a SELECT ...
FROM {table} query. + + Will add to SRC_STATS["auto__{table}"] + For example { + "count": { + "clause": "COUNT(*)", + "comment": "number of rows in table {table}" + }, "avg_thiscolumn": { + "clause": "AVG(thiscolumn)", + "comment": "Average value of thiscolumn in table {table}" + }} + will make the clause become: + "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable" + and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and + SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file. + """ + return {} + + def custom_queries(self) -> dict[str, dict[str, str]]: + """ + Get the SQL queries to add to SRC_STATS. + + Should be used for queries that do not follow the SELECT ... FROM table format + using aggregate queries, because these should use select_aggregate_clauses. + + For example {"myquery": { + "query": "SELECT one, too AS two FROM mytable WHERE too > 1", + "comment": "big enough one and two from table mytable" + }} + will populate SRC_STATS["myquery"]["results"][0]["one"] + and SRC_STATS["myquery"]["results"][0]["two"] + in the src-stats.yaml file. + + Keys should be chosen to minimize the chances of clashing with other queries, + for example "auto__{table}__{column}__{queryname}" + """ + return {} + + @abstractmethod + def actual_kwargs(self) -> dict[str, Any]: + """ + Get the kwargs (summary statistics) this generator is instantiated with. + + This must match `nominal_kwargs` in structure. + """ + + @abstractmethod + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + + def fit(self, default: float = -1) -> float: + """ + Return a value representing how well the distribution fits the real source data. + + 0.0 means "perfectly". + Returns default if no fitness has been defined. 
+ """ + return default + + +class PredefinedGenerator(Generator): + """Generator built from an existing config.yaml.""" + + SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") + AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") + SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') + + def _get_src_stats_mentioned(self, val: Any) -> set[str]: + if not val: + return set() + if isinstance(val, str): + ss = self.SRC_STAT_NAME_RE.match(val) + if ss: + ss_name = ss.group(1) + logger.debug("Found SRC_STATS reference %s", ss_name) + return set([ss_name]) + logger.debug("Value %s does not seem to be a SRC_STATS reference", val) + return set() + if isinstance(val, list): + return set.union(*(self._get_src_stats_mentioned(v) for v in val)) + if isinstance(val, dict): + return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) + return set() + + def __init__( + self, + table_name: str, + generator_object: Mapping[str, Any], + config: Mapping[str, Any], + ): + """ + Initialise a generator from a config.yaml. + + :param config: The entire configuration. + :param generator_object: The part of the configuration at tables.*.row_generators + """ + logger.debug( + "Creating a PredefinedGenerator %s from table %s", + generator_object["name"], + table_name, + ) + self._table_name = table_name + self._name: str = generator_object["name"] + self._kwn: dict[str, str] = generator_object.get("kwargs", {}) + self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) + # Need to deal with this somehow (or remove it from the schema) + self._argn: list[str] = generator_object.get("args", []) + self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} + self._custom_queries = {} + for sstat in config.get("src-stats", []): + name: str = sstat["name"] + dpq = sstat.get("dp-query", None) + query = sstat.get( + "query", dpq + ) # ... should we really be combining query and dp-query? 
+ comments = sstat.get("comments", []) + if name in self._src_stats_mentioned: + logger.debug("Found a src-stats entry for %s", name) + # This query is one that this generator is interested in + sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) + # sam.group(2) is the table name from the FROM clause of the query + if sam and name == f"auto__{sam.group(2)}": + # name is auto__{table_name}, so it's a select_aggregate, + # so we split up its clauses + sacs = [ + self.AS_CLAUSE_RE.match(clause) + for clause in sam.group(1).split(",") + ] + # Work out what select_aggregate_clauses this represents + for sac in sacs: + if sac is not None: + comment = comments.pop() if comments else None + self._select_aggregate_clauses[sac.group(2)] = { + "clause": sac.group(1), + "comment": comment, + } + else: + # some other name, so must be a custom query + logger.debug("Custom query %s is '%s'", name, query) + self._custom_queries[name] = { + "query": query, + "comment": comments[0] if comments else None, + } + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" + return self._kwn + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + return self._select_aggregate_clauses + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + return self._custom_queries + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + # Run the queries from nominal_kwargs + # ... + logger.error("PredefinedGenerator.actual_kwargs not implemented yet") + return {} + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + # Call the function if we can. 
This could be tricky... + # ... + logger.error("PredefinedGenerator.generate_data not implemented yet") + return [] + + +class GeneratorFactory(ABC): + """A factory for making generators appropriate for a database column.""" + + @abstractmethod + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + + +def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: + """Calculate the fit by comparing a pair of lists of buckets.""" + sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) + count = len(ys) + return sum_diff_squared / (count * count) + + +class Buckets: + """ + Measured buckets for a real distribution. + + Finds the real distribution of continuous data so that we can measure + the fit of generators against it. + """ + + def __init__( + self, + engine: Engine, + table_name: str, + column_name: str, + mean: float, + stddev: float, + count: int, + ): + """Initialise a Buckets object.""" + with engine.connect() as connection: + raw_buckets = connection.execute( + text( + f"SELECT COUNT({column_name}) AS f," + f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b" + f" FROM {table_name} GROUP BY b" + ) + ) + self.buckets: Sequence[int] = [0] * 10 + for rb in raw_buckets: + if rb.b is not None: + bucket = min(9, max(0, int(rb.b) + 1)) + self.buckets[bucket] += rb.f / count + self.mean = mean + self.stddev = stddev + + @classmethod + def make_buckets( + cls, engine: Engine, table_name: str, column_name: str + ) -> Self | None: + """ + Construct a Buckets object. + + Calculates the mean and standard deviation of the values in the column + specified and makes ten buckets, centered on the mean and each half + a standard deviation wide (except for the end two that extend to + infinity). Each bucket will be set to the count of the number of values + in the column within that bucket. 
+ """ + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT AVG({column_name}) AS mean," + f" STDDEV({column_name}) AS stddev," + f" COUNT({column_name}) AS count FROM {table_name}" + ) + ).first() + if result is None or result.stddev is None or getattr(result, "count") < 2: + return None + try: + buckets = cls( + engine, + table_name, + column_name, + result.mean, + result.stddev, + getattr(result, "count"), + ) + except sqlalchemy.exc.DatabaseError as exc: + logger.debug("Failed to instantiate Buckets object: %s", exc) + return None + return buckets + + def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: + """Figure out the fit from bucket counts from the generator distribution.""" + return fit_from_buckets(self.buckets, bucket_counts) + + def fit_from_values(self, values: list[float]) -> float: + """Figure out the fit from samples from the generator distribution.""" + buckets = [0] * 10 + x = self.mean - 2 * self.stddev + w = self.stddev / 2 + for v in values: + b = min(9, max(0, int((v - x) / w))) + buckets[b] += 1 + return self.fit_from_counts(buckets) + + +class MultiGeneratorFactory(GeneratorFactory): + """A composite factory.""" + + def __init__(self, factories: list[GeneratorFactory]): + """Initialise a MultiGeneratorFactory.""" + super().__init__() + self.factories = factories + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + return [ + generator + for factory in self.factories + for generator in factory.get_generators(columns, engine) + ] + + +def get_column_type(column: Column) -> TypeEngine: + """Get the type of the column, generic if possible.""" + try: + return column.type.as_generic() + except NotImplementedError: + return column.type + + +class ConstantGenerator(Generator): + """Generator that always produces the same value.""" + + def __init__(self, value: Any) -> None: + """Initialise the 
ConstantGenerator.""" + super().__init__() + self.value = value + self.repr = repr(value) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.constant" + + def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" + return {"value": self.repr} + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {"value": self.value} + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self.value for _ in range(count)] + + +class ConstantGeneratorFactory(GeneratorFactory): + """Just the null generator.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate for these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + if column.nullable: + return [ConstantGenerator(None)] + c_type = get_column_type(column) + if isinstance(c_type, String): + return [ConstantGenerator("")] + if isinstance(c_type, Numeric): + return [ConstantGenerator(0.0)] + if isinstance(c_type, Integer): + return [ConstantGenerator(0)] + return [] diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py new file mode 100644 index 00000000..61147a1f --- /dev/null +++ b/datafaker/generators/choice.py @@ -0,0 +1,398 @@ +"""Generator factories for making generators for choices of values.""" + +import decimal +import math +import typing +from abc import abstractmethod +from typing import Any, Sequence, Union + +from datafaker.generators.base import ( + Generator, + GeneratorFactory, + dist_gen, + fit_from_buckets, +) +from sqlalchemy import Column, CursorResult, Engine, text + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? 
+MAXIMUM_CHOICES = 500 + + +def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Get a zipf distribution for a certain number of items. + + :param total: The total number of items to be distributed. + :param bins: The total number of bins to distribute the items into. + :return: A generator of the number of items in each bin, from the + largest to the smallest. + """ + basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) + bd_remaining = sum(basic_dist) + for b in basic_dist: + # yield b/bd_remaining of the `total` remaining + if bd_remaining == 0: + yield 0 + else: + x = math.floor(0.5 + total * b / bd_remaining) + bd_remaining -= x * bd_remaining / total + total -= x + yield x + + +class ChoiceGenerator(Generator): + """Base generator for all generators producing choices of items.""" + + STORE_COUNTS = False + + def __init__( + self, + table_name: str, + column_name: str, + values: list[Any], + counts: list[int], + sample_count: int | None = None, + suppress_count: int = 0, + ) -> None: + """Initialise a ChoiceGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.values = values + estimated_counts = self.get_estimated_counts(counts) + self._fit = fit_from_buckets(counts, estimated_counts) + + extra_results = "" + extra_expo = "" + extra_comment = "" + if self.STORE_COUNTS: + extra_results = f", COUNT({column_name}) AS count" + extra_expo = ", count" + extra_comment = " and their counts" + if suppress_count == 0: + if sample_count is None: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL GROUP BY value" + f" ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name}" + ) + self._annotation = None + else: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM" + f" (SELECT 
{column_name} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL" + f" ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"The values{extra_comment} that appear in column {column_name}" + f" of a random sample of {sample_count} rows of table {table_name}" + ) + self._annotation = "sampled" + else: + if sample_count is None: + self._query = ( + f"SELECT value{extra_expo} FROM" + f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" + f" FROM {table_name} WHERE {column_name} IS NOT NULL" + f" GROUP BY value ORDER BY count DESC) AS _inner" + f" WHERE {suppress_count} < count" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name} more than {suppress_count} times" + ) + self._annotation = "suppressed" + else: + self._query = ( + f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" + f" (SELECT {column_name} AS value FROM {table_name}" + f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY count DESC)" + f" AS _inner WHERE {suppress_count} < count" + ) + self._comment = ( + f"The values{extra_comment} that appear more than {suppress_count} times" + f" in column {column_name}, out of a random sample of {sample_count} rows" + f" of table {table_name}" + ) + self._annotation = "sampled and suppressed" + + @abstractmethod + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', + } + + def name(self) -> str: + """Get the name of the generator.""" + n = super().name() + if self._annotation is None: + return n + return f"{n} 
[{self._annotation}]" + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "a": self.values, + } + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + qs = super().custom_queries() + return { + **qs, + f"auto__{self.table_name}__{self.column_name}": { + "query": self._query, + "comment": self._comment, + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class ZipfChoiceGenerator(ChoiceGenerator): + """Generator producing items in a Zipf distribution.""" + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return list(zipf_distribution(sum(counts), len(counts))) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.zipf_choice" + + def generate_data(self, count: int) -> list[float]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.zipf_choice_direct(self.values, len(self.values)) + for _ in range(count) + ] + + +def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Construct a distribution putting ``total`` items uniformly into ``bins`` bins. + + If they don't fit exactly evenly, the earlier bins will have one more + item than the later bins so the total is as required. 
+ """ + p = total // bins + n = total % bins + for _ in range(0, n): + yield p + 1 + for _ in range(n, bins): + yield p + + +class UniformChoiceGenerator(ChoiceGenerator): + """A generator producing values, each roughly as frequently as each other.""" + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return list(uniform_distribution(sum(counts), len(counts))) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.choice" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.choice_direct(self.values) for _ in range(count)] + + +class WeightedChoiceGenerator(ChoiceGenerator): + """Choice generator that matches the source data's frequency.""" + + STORE_COUNTS = True + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return counts + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.weighted_choice" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.weighted_choice(self.values) for _ in range(count)] + + +class ValueGatherer: + """ + Gathers values from a query of values and counts. + + The query must return columns ``v`` for a value and ``f`` for the + count of how many of those values there are. + These values will be gathered into a number of properties: + ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts + in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and + ``count`` giving these values and counts. 
``counts_not_suppressed``, + ``values_not_suppressed`` and ``cvs_not_suppressed`` are the + equivalents with the counts less than or equal to ``suppress_count`` + removed. + + :param suppress_count: values with a count of this or fewer will be excluded + from the ``*_not_suppressed`` lists. + """ + + def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: + """Initialise a ValueGatherer.""" + values = [] # All values found + counts = [] # The number of each value + cvs: list[dict[str, Any]] = [] # list of dicts with keys "value" and "count" + values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times + counts_not_suppressed = [] # The number for each value not suppressed + cvs_not_suppressed: list[ + dict[str, Any] + ] = [] # list of dicts with keys "value" and "count" + for result in results: + c = result.f + if c != 0: + counts.append(c) + v = result.v + if isinstance(v, decimal.Decimal): + v = float(v) + values.append(v) + cvs.append({"value": v, "count": c}) + if suppress_count < c: + counts_not_suppressed.append(c) + v = result.v + if isinstance(v, decimal.Decimal): + v = float(v) + values_not_suppressed.append(v) + cvs_not_suppressed.append({"value": v, "count": c}) + self.values = values + self.counts = counts + self.cvs = cvs + self.values_not_suppressed = values_not_suppressed + self.counts_not_suppressed = counts_not_suppressed + self.cvs_not_suppressed = cvs_not_suppressed + + +class ChoiceGeneratorFactory(GeneratorFactory): + """All generators that choose between the values found in a column.""" + + SAMPLE_COUNT = MAXIMUM_CHOICES + SUPPRESS_COUNT = 5 + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + column_name = column.name + table_name = column.table.name + generators = [] + with engine.connect() as connection: + results = connection.execute( + text( + f"SELECT {column_name}
AS v, COUNT({column_name})" + f" AS f FROM {table_name} GROUP BY v" + f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" + ) + ) + if results is not None and results.rowcount <= MAXIMUM_CHOICES: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), + ] + results = connection.execute( + text( + f"SELECT v, COUNT(v) AS f FROM" + f" (SELECT {column_name} as v FROM {table_name}" + f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" + f" AS _inner GROUP BY v ORDER BY f DESC" + ) + ) + if results is not None: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), + ] + generators += [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + WeightedChoiceGenerator( + table_name, + column_name, + vg.cvs, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + ] + if vg.counts_not_suppressed: + generators += [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + WeightedChoiceGenerator( + table_name=table_name, + column_name=column_name, + values=vg.cvs_not_suppressed, + 
counts=vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + ] + return generators diff --git a/datafaker/generators/continuous.py b/datafaker/generators/continuous.py new file mode 100644 index 00000000..a84d965a --- /dev/null +++ b/datafaker/generators/continuous.py @@ -0,0 +1,471 @@ +"""Generator factories for making generators of continuous distributions.""" + +from typing import Any, Sequence + +from datafaker.generators.base import ( + Buckets, + Generator, + GeneratorFactory, + NumericType, + get_column_type, +) +from sqlalchemy import Column, Engine, RowMapping, text +from sqlalchemy.types import Integer, Numeric + +from datafaker.generators.base import dist_gen +from datafaker.utils import logger + + +class ContinuousDistributionGenerator(Generator): + """Base class for generators producing continuous distributions.""" + + expected_buckets: Sequence[NumericType] = [] + + def __init__(self, table_name: str, column_name: str, buckets: Buckets): + """Initialise a ContinuousDistributionGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "mean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["mean__{self.column_name}"]' + ), + "sd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["stddev__{self.column_name}"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + if self.buckets is None: + return {} + return { + "mean": self.buckets.mean, + "sd": self.buckets.stddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"mean__{self.column_name}": { + 
"clause": f"AVG({self.column_name})", + "comment": f"Mean of {self.column_name} from table {self.table_name}", + }, + f"stddev__{self.column_name}": { + "clause": f"STDDEV({self.column_name})", + "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class GaussianGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a Gaussian (normal) distribution.""" + + expected_buckets = [ + 0.0227, + 0.0441, + 0.0918, + 0.1499, + 0.1915, + 0.1915, + 0.1499, + 0.0918, + 0.0441, + 0.0227, + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.normal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.normal(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class UniformGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a uniform distribution.""" + + expected_buckets = [ + 0, + 0.06698, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.06698, + 0, + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.uniform_ms" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class ContinuousDistributionGeneratorFactory(GeneratorFactory): + """All generators that want an average and standard deviation.""" + + def _get_generators_from_buckets( + self, + _engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + return [ + 
GaussianGenerator(table_name, column_name, buckets), + UniformGenerator(table_name, column_name, buckets), + ] + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_name = column.name + table_name = column.table.name + buckets = Buckets.make_buckets(engine, table_name, column_name) + if buckets is None: + return [] + return self._get_generators_from_buckets( + engine, table_name, column_name, buckets + ) + + +class LogNormalGenerator(Generator): + """Generator producing numbers in a log-normal distribution.""" + + # TODO: figure out the real buckets here (this was from a random sample in R) + expected_buckets = [ + 0, + 0, + 0, + 0.28627, + 0.40607, + 0.14937, + 0.06735, + 0.03492, + 0.01918, + 0.03684, + ] + + def __init__( + self, + table_name: str, + column_name: str, + buckets: Buckets, + logmean: float, + logstddev: float, + ): + """Initialise a LogNormalGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + self.logmean = logmean + self.logstddev = logstddev + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.lognormal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "logmean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logmean__{self.column_name}"]' + ), + "logsd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logstddev__{self.column_name}"]' + ), + } 
+ + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "logmean": self.logmean, + "logsd": self.logstddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"logmean__{self.column_name}": { + "clause": ( + f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" + " ELSE NULL END)" + ), + "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", + }, + f"logstddev__{self.column_name}": { + "clause": ( + f"STDDEV(CASE WHEN 0<{self.column_name}" + f" THEN LN({self.column_name}) ELSE NULL END)" + ), + "comment": ( + f"Standard deviation of logs of {self.column_name}" + f" from table {self.table_name}" + ), + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): + """All generators that want an average and standard deviation of log data.""" + + def _get_generators_from_buckets( + self, + engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" + " ELSE NULL END) AS logmean," + f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" + f" AS logstddev FROM {table_name}" + ) + ).first() + if result is None or result.logstddev is None: + return [] + return [ + LogNormalGenerator( + table_name, + column_name, + buckets, + float(result.logmean), + float(result.logstddev), + ) + ] + + +class MultivariateNormalGenerator(Generator): + """Generator of 
multiple values drawn from a multivariate normal distribution.""" + + def __init__( + self, + table_name: str, + column_names: list[str], + query: str, + covariates: RowMapping, + function_name: str, + ) -> None: + """Initialise a MultivariateNormalGenerator.""" + self._table = table_name + self._columns = column_names + self._query = query + self._covariates = covariates + self._function_name = function_name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen." + self._function_name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + cols = ", ".join(self._columns) + return { + f"auto__cov__{self._table}": { + "comment": ( + f"Means and covariate matrix for the columns {cols}," + " so that we can produce the relatedness between these in the fake data." 
+ ), + "query": self._query, + } + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {"cov": self._covariates} + + def generate_data(self, count: int) -> list[Any]: + """Generate 'count' random data points for this column.""" + return [ + getattr(dist_gen, self._function_name)(self._covariates) + for _ in range(count) + ] + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default + + +class MultivariateNormalGeneratorFactory(GeneratorFactory): + """Normal distribution generator factory.""" + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "multivariate_normal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" + return column.name + " IS NOT NULL" + + def query_var(self, column: str) -> str: + """Get the SQL expression of the value to query for this column.""" + return column + + def query( + self, + table: str, + columns: list[Column], + predicates: list[str] = [], + group_by_clause: str = "", + constant_clauses: str = "", + constants: str = "", + suppress_count: int = 1, + sample_count: int | None = None, + ) -> str: + """ + Get a query for the basics for multivariate normal/lognormal parameters. + + :param table: The name of the table to be queried. + :param columns: The columns in the multivariate distribution. + :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. + :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not ""). + :param constant_clauses: Extra output columns in the outer SELECT clause, such + as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma. + :param constants: Extra output columns in the inner SELECT clause. Used to + deliver columns to the outer select, such as ", column_one, column_two". 
+ Note the initial comma. + :param suppress_count: a group smaller than this will be suppressed. + :param sample_count: this many samples will be taken from each partition. + """ + preds = [self.query_predicate(col) for col in columns] + predicates + where = " WHERE " + " AND ".join(preds) if preds else "" + avgs = "".join( + f", AVG({self.query_var(col.name)}) AS m{i}" + for i, col in enumerate(columns) + ) + multiples = "".join( + f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" + for iy, coly in enumerate(columns) + for ix, colx in enumerate(columns[: iy + 1]) + ) + means = "".join(f", _q.m{i}" for i in range(len(columns))) + covs = "".join( + ( + f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" + f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" + ) + for iy in range(len(columns)) + for ix in range(iy + 1) + ) + if sample_count is None: + subquery = table + where + else: + subquery = ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) + # if there are any numeric columns we need at least# + # two rows to make any (co)variances at all + suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" + return ( + f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" + f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" + f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" + ) + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators for these columns.""" + # For the case of one column we'll use GaussianGenerator + if len(columns) < 2: + return [] + # All columns must be numeric + for c in columns: + ct = get_column_type(c) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_names = [c.name for c in columns] + table = columns[0].table.name + query = self.query(table, columns) + with engine.connect() as connection: + try: + 
covariates = connection.execute(text(query)).mappings().first() + except Exception as e: + logger.debug("SQL query %s failed with error %s", query, e) + return [] + if not covariates or covariates["c0_0"] is None: + return [] + return [ + MultivariateNormalGenerator( + table, + column_names, + query, + covariates, + self.function_name(), + ) + ] + + +class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Multivariate lognormal generator factory.""" + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "multivariate_lognormal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" + return f"COALESCE(0 < {column.name}, FALSE)" + + def query_var(self, column: str) -> str: + """Get the expression to query for, for this column.""" + return f"LN({column})" diff --git a/datafaker/generators/mimesis.py b/datafaker/generators/mimesis.py new file mode 100644 index 00000000..8d031aba --- /dev/null +++ b/datafaker/generators/mimesis.py @@ -0,0 +1,418 @@ +"""Generators using Mimesis.""" + +from typing import Any, Callable, Sequence, Union + +import mimesis +import mimesis.locales +from datafaker.generators.base import ( + Buckets, + Generator, + GeneratorFactory, + get_column_type, +) +from sqlalchemy import Column, Engine, text +from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time + +from datafaker.base import DistributionGenerator + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? +MAXIMUM_CHOICES = 500 + +dist_gen = DistributionGenerator() +generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) + + +class MimesisGeneratorBase(Generator): + """Base class for a generator using Mimesis.""" + + def __init__( + self, + function_name: str, + ): + """ + Initialise a generator that uses Mimesis. 
+ + :param function_name: is relative to 'generic', for example 'person.name'. + """ + super().__init__() + f = generic + for part in function_name.split("."): + if not hasattr(f, part): + raise Exception( + f"Mimesis does not have a function {function_name}: {part} not found" + ) + f = getattr(f, part) + if not callable(f): + raise Exception( + f"Mimesis object {function_name} is not a callable," + " so cannot be used as a generator" + ) + self._name = "generic." + function_name + self._generator_function = f + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self._generator_function() for _ in range(count)] + + +class MimesisGenerator(MimesisGeneratorBase): + """A generator using Mimesis.""" + + def __init__( + self, + function_name: str, + value_fn: Callable[[Any], float] | None = None, + buckets: Buckets | None = None, + ): + """ + Initialise a generator using Mimesis. + + :param function_name: is relative to 'generic', for example 'person.name'. + :param value_fn: Function to convert generator output to floats, if needed. The values + thus produced are compared against the buckets to estimate the fit. + :param buckets: The distribution of string lengths in the real data. If this is None + then the fit method will return None. 
+ """ + super().__init__(function_name) + if buckets is None: + self._fit = None + return + samples = self.generate_data(400) + if value_fn: + samples = [value_fn(s) for s in samples] + self._fit = buckets.fit_from_values(samples) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return {} + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {} + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class MimesisGeneratorTruncated(MimesisGenerator): + """A string generator using Mimesis that must fit within a certain number of characters.""" + + def __init__( + self, + function_name: str, + length: int, + value_fn: Callable[[Any], float] | None = None, + buckets: Buckets | None = None, + ): + """Initialise a MimesisGeneratorTruncated.""" + self._length = length + super().__init__(function_name, value_fn, buckets) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.truncated_string" + + def name(self) -> str: + """Get the name of the generator.""" + return f"{self._name} [truncated to {self._length}]" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return 
[self._generator_function()[: self._length] for _ in range(count)] + + +class MimesisDateTimeGenerator(MimesisGeneratorBase): + """DateTime generator using Mimesis.""" + + def __init__( + self, + column: Column, + function_name: str, + min_year: str, + max_year: str, + start: int, + end: int, + ) -> None: + """ + Initialise a MimesisDateTimeGenerator. + + :param column: The column to generate into + :param function_name: The name of the mimesis function + :param min_year: SQL expression extracting the minimum year + :param min_year: SQL expression extracting the maximum year + :param start: The actual first year found + :param end: The actual last year found + """ + super().__init__(function_name) + self._column = column + self._max_year = max_year + self._min_year = min_year + self._start = start + self._end = end + + @classmethod + def make_singleton( + cls, column: Column, engine: Engine, function_name: str + ) -> Sequence[Generator]: + """Make the appropriate generation configuration for this column.""" + extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" + max_year = f"MAX({extract_year})" + min_year = f"MIN({extract_year})" + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}" + ) + ).first() + if result is None or result.start is None or result.end is None: + return [] + return [ + MimesisDateTimeGenerator( + column, + function_name, + min_year, + max_year, + int(result.start), + int(result.end), + ) + ] + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "start": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__start"]' + ), + "end": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__end"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this 
generator was instantiated with.""" + return { + "start": self._start, + "end": self._end, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + return { + f"{self._column.name}__start": { + "clause": self._min_year, + "comment": ( + f"Earliest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), + }, + f"{self._column.name}__end": { + "clause": self._max_year, + "comment": ( + f"Latest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), + }, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + self._generator_function(start=self._start, end=self._end) + for _ in range(count) + ] + + +class MimesisStringGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return strings.""" + + GENERATOR_NAMES = [ + "address.calling_code", + "address.city", + "address.continent", + "address.country", + "address.country_code", + "address.postal_code", + "address.province", + "address.street_number", + "address.street_name", + "address.street_suffix", + "person.blood_type", + "person.email", + "person.first_name", + "person.last_name", + "person.full_name", + "person.gender", + "person.language", + "person.nationality", + "person.occupation", + "person.password", + "person.title", + "person.university", + "person.username", + "person.worldview", + "text.answer", + "text.color", + "text.level", + "text.quote", + "text.sentence", + "text.text", + "text.word", + ] + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + column_type = get_column_type(column) + if not isinstance(column_type, String): + return [] + try: + buckets = Buckets.make_buckets( + engine, + column.table.name, + 
f"LENGTH({column.name})", + ) + fitness_fn = len + except Exception: + # Some column types that appear to be strings (such as enums) + # cannot have their lengths measured. In this case we cannot + # detect fitness using lengths. + buckets = None + fitness_fn = None + length = column_type.length + if length: + return list( + map( + lambda gen: MimesisGeneratorTruncated( + gen, length, fitness_fn, buckets + ), + self.GENERATOR_NAMES, + ) + ) + return list( + map( + lambda gen: MimesisGenerator(gen, fitness_fn, buckets), + self.GENERATOR_NAMES, + ) + ) + + +class MimesisFloatGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return floating point numbers.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + if not isinstance(get_column_type(column), Numeric): + return [] + return list( + map( + MimesisGenerator, + [ + "person.height", + ], + ) + ) + + +class MimesisDateGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return dates.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Date): + return [] + return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") + + +class MimesisDateTimeGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return datetimes.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, DateTime): + return [] + return MimesisDateTimeGenerator.make_singleton( + column, engine, 
"datetime.datetime" + ) + + +class MimesisTimeGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return times.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Time): + return [] + return [MimesisGenerator("datetime.time")] + + +class MimesisIntegerGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return integers.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + return [MimesisGenerator("person.weight")] diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py new file mode 100644 index 00000000..395f2614 --- /dev/null +++ b/datafaker/generators/partitioned.py @@ -0,0 +1,514 @@ +"""Powerful generators for numbers, choices and related missingness.""" + +from dataclasses import dataclass +from itertools import chain, combinations +from typing import Any, Iterable, Sequence, Union + +import sqlalchemy +from datafaker.generators.base import ( + Generator, + dist_gen, + get_column_type, +) +from datafaker.generators.continuous import ( + MultivariateNormalGeneratorFactory, +) +from sqlalchemy import Column, Connection, Engine, RowMapping, text +from sqlalchemy.types import Integer, Numeric + +from datafaker.utils import T, logger + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? 
+MAXIMUM_CHOICES = 500 + + +def text_list(items: Iterable[str]) -> str: + """Concatenate the items with commas and one "and".""" + item_i = iter(items) + try: + last_item = next(item_i) + except StopIteration: + return "" + try: + so_far = next(item_i) + except StopIteration: + return last_item + for item in item_i: + so_far += ", " + last_item + last_item = item + return so_far + " and " + last_item + + +@dataclass +class RowPartition: + """A partition where all the rows have the same pattern of NULLs.""" + + query: str + # list of numeric columns + included_numeric: list[Column] + # map of indices to column names that are being grouped by. + # The indices are indices of where they need to be inserted into + # the generator outputs. + included_choice: dict[int, str] + # map of column names to clause that defines the partition + # such as "mycolumn IS NULL" + excluded_columns: dict[str, str] + # map of constant outputs that need to be inserted into the + # list of included column values (so once the generator has + # been run and the included_choice values have been + # added): {index: value} + constant_outputs: dict[int, Any] + # The actual covariates from the source database + covariates: Sequence[RowMapping] + + def comment(self) -> str: + """Make an appropriate comment for this partition.""" + caveat = "" + if self.included_choice: + caveat = f" (for each possible value of {text_list(self.included_choice.values())})" + if not self.included_numeric: + return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}" + if not self.excluded_columns: + where = "" + else: + where = f" where {text_list(self.excluded_columns.values())}" + if len(self.included_numeric) == 1: + return ( + f"Mean and variance for column {self.included_numeric[0].name}{where}." 
+ ) + return ( + "Means and covariate matrix for the columns " + f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" + " produce the relatedness between these in the fake data." + ) + + +class NullPartitionedNormalGenerator(Generator): + """ + A generator of mixed numeric and non-numeric data. + + Generates data that matches the source data in + missingness, choice of non-numeric data and numeric + data. + + For the numeric data to be generated, samples of rows for each + combination of non-numeric values and missingness. If any such + combination has only one line in the source data (or sample of + the source data if sampling), it will not be generated as a + covariate matrix cannot be generated from one source row + (although if the data is all non-numeric values and nulls, single + rows are used because no covariate matrix is required for this). + """ + + def __init__( + self, + query_name: str, + partitions: dict[int, RowPartition], + function_name: str = "grouped_multivariate_lognormal", + name_suffix: str | None = None, + partition_count_query: str | None = None, + partition_counts: Iterable[RowMapping] = [], + partition_count_comment: str | None = None, + ): + """Initialise a NullPartitionedNormalGenerator.""" + self._query_name = query_name + self._partitions = partitions + self._function_name = function_name + self._partition_count_query = partition_count_query + self._partition_counts = [dict(pc) for pc in partition_counts] + self._partition_count_comment = partition_count_comment + if name_suffix: + self._name = f"null-partitioned {function_name} [{name_suffix}]" + else: + self._name = f"null-partitioned {function_name}" + + def name(self) -> str: + """Get the name of the generator.""" + return self._name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.alternatives" + + def _nominal_kwargs_with_combinations( + self, index: int, partition: RowPartition + ) 
-> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml`` for a single partition.""" + count = ( + 'sum(r["count"] for r in' + f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' + ) + if not partition.included_numeric and not partition.included_choice: + return { + "count": count, + "name": '"constant"', + "params": {"value": [None] * len(partition.constant_outputs)}, + } + covariates = { + "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' + } + if not partition.constant_outputs: + return { + "count": count, + "name": f'"{self._function_name}"', + "params": covariates, + } + return { + "count": count, + "name": '"with_constants_at"', + "params": { + "constants_at": partition.constant_outputs, + "subgen": f'"{self._function_name}"', + "params": covariates, + }, + } + + def _count_query_name(self) -> str: + return f"auto__cov__{self._query_name}__counts" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "alternative_configs": [ + self._nominal_kwargs_with_combinations(index, self._partitions[index]) + for index in range(len(self._partitions)) + ], + "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + partitions = { + f"auto__cov__{self._query_name}__alt_{index}": { + "comment": partition.comment(), + "query": partition.query, + } + for index, partition in self._partitions.items() + } + if not self._partition_count_query: + return partitions + return { + self._count_query_name(): { + "comment": self._partition_count_comment, + "query": self._partition_count_query, + }, + **partitions, + } + + def _actual_kwargs_with_combinations( + self, partition: RowPartition + ) -> dict[str, Any]: + count = sum(row["count"] for row in partition.covariates) + if not partition.included_numeric and not partition.included_choice: 
+ return { + "count": count, + "name": "constant", + "params": {"value": [None] * len(partition.excluded_columns)}, + } + covariates = { + "covs": partition.covariates, + } + if not partition.constant_outputs: + return { + "count": count, + "name": self._function_name, + "params": covariates, + } + return { + "count": count, + "name": "with_constants_at", + "params": { + "constants_at": partition.constant_outputs, + "subgen": self._function_name, + "params": covariates, + }, + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "alternative_configs": [ + self._actual_kwargs_with_combinations(self._partitions[index]) + for index in range(len(self._partitions)) + ], + "counts": self._partition_counts, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate 'count' random data points for this column.""" + kwargs = self.actual_kwargs() + return [dist_gen.alternatives(**kwargs) for _ in range(count)] + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default + + +def is_numeric(col: Column) -> bool: + """Test if this column stores a numeric value.""" + ct = get_column_type(col) + return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys + + +def powerset(xs: list[T]) -> Iterable[Iterable[T]]: + """Get a list of all sublists of ``input``.""" + return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) + + +@dataclass +class NullableColumn: + """A reference to a nullable column whose nullability is part of a partitioning.""" + + column: Column + # The bit (power of two) of the number of the partition in the partition sizes list + bitmask: int + + +class NullPatternPartition: + """Get the definition of a partition (in other words, what makes it not another partition).""" + + def __init__( + self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] + ): + 
"""Initialise a pattern of nulls which can be queried for.""" + self.index = sum(nc.bitmask for nc in partition_nonnulls) + nonnull_columns = {nc.column.name for nc in partition_nonnulls} + self.included_numeric: list[Column] = [] + self.included_choice: dict[int, str] = {} + self.group_by_clause = "" + self.constant_clauses = "" + self.constants = "" + self.excluded: dict[str, str] = {} + self.predicates: list[str] = [] + self.nones: dict[int, None] = {} + for col_index, column in enumerate(columns): + col_name = column.name + if col_name in nonnull_columns or not column.nullable: + if is_numeric(column): + self.included_numeric.append(column) + else: + index = len(self.included_numeric) + len(self.included_choice) + self.included_choice[index] = col_name + if self.group_by_clause: + self.group_by_clause += ", " + col_name + else: + self.group_by_clause = " GROUP BY " + col_name + self.constant_clauses += f", _q.{col_name} AS k{index}" + self.constants += ", " + col_name + else: + self.excluded[col_name] = f"{col_name} IS NULL" + self.predicates.append(f"{col_name} IS NULL") + self.nones[col_index] = None + + +class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Produces null partitioned generators, for complex interdependent data.""" + + SAMPLE_COUNT = MAXIMUM_CHOICES + SUPPRESS_COUNT = 5 + EMPTY_RESULT = [ + RowMapping( + parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), + processors=None, + key_to_index={"count": 0}, + data=(0,), + ) + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "grouped_multivariate_normal" + + def query_predicate(self, column: Column) -> str: + """Get a SQL expression that is true when ``column`` is available for analysis.""" + if is_numeric(column): + # x <> x + 1 ensures that x is not infinity or NaN + return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" + return f"{column.name} IS NOT NULL" + + def query_var(self, column: 
str) -> str: + """Return the expression we are querying for in this column.""" + return column + + def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: + """Get a list of nullable columns together with bitmasks.""" + out: list[NullableColumn] = [] + for col in columns: + if col.nullable: + out.append( + NullableColumn( + column=col, + bitmask=2 ** len(out), + ) + ) + return out + + def get_partition_count_query( + self, ncs: list[NullableColumn], table: str, where: str | None = None + ) -> str: + """ + Get a SQL expression returning columns ``count`` and ``index``. + + Each row returned represents one of the null pattern partitions. + ``index`` is the bitmask of all those nullable columns that are not null for + this partition, and ``count`` is the total number of rows in this partition. + """ + index_exp = " + ".join( + f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" + for nc in ncs + ) + if where is None: + return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' + return ( + 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' + f' {index_exp} AS "index"' + f' FROM {table} GROUP BY "index") AS _q {where}' + ) + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get any appropriate generators for these columns.""" + if len(columns) < 2: + return [] + nullable_columns = self.get_nullable_columns(columns) + if not nullable_columns: + return [] + table = columns[0].table.name + query_name = f"{table}__{columns[0].name}" + # Partitions for minimal suppression and no sampling + row_partitions_maximal: dict[int, RowPartition] = {} + # Partitions for normal suppression and severe sampling + row_partitions_ss: dict[int, RowPartition] = {} + for partition_nonnulls in powerset(nullable_columns): + partition_def = NullPatternPartition(columns, partition_nonnulls) + query = self.query( + table=table, + 
columns=partition_def.included_numeric, + predicates=partition_def.predicates, + group_by_clause=partition_def.group_by_clause, + constants=partition_def.constants, + constant_clauses=partition_def.constant_clauses, + ) + row_partitions_maximal[partition_def.index] = RowPartition( + query, + partition_def.included_numeric, + partition_def.included_choice, + partition_def.excluded, + partition_def.nones, + [], + ) + query = self.query( + table=table, + columns=partition_def.included_numeric, + predicates=partition_def.predicates, + group_by_clause=partition_def.group_by_clause, + constants=partition_def.constants, + constant_clauses=partition_def.constant_clauses, + suppress_count=self.SUPPRESS_COUNT, + sample_count=self.SAMPLE_COUNT, + ) + row_partitions_ss[partition_def.index] = RowPartition( + query, + partition_def.included_numeric, + partition_def.included_choice, + partition_def.excluded, + partition_def.nones, + [], + ) + gens: list[Generator] = [] + try: + with engine.connect() as connection: + partition_query_max = self.get_partition_count_query( + nullable_columns, table + ) + partition_count_max_results = ( + connection.execute(text(partition_query_max)).mappings().fetchall() + ) + count_comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table} being null" + ) + if self._execute_partition_queries(connection, row_partitions_maximal): + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_maximal, + self.function_name(), + partition_count_query=partition_query_max, + partition_counts=partition_count_max_results, + partition_count_comment=count_comment, + ) + ) + partition_query_ss = self.get_partition_count_query( + nullable_columns, + table, + where=f"WHERE {self.SUPPRESS_COUNT} < count", + ) + partition_count_ss_results = ( + connection.execute(text(partition_query_ss)).mappings().fetchall() + ) + if self._execute_partition_queries(connection, 
row_partitions_ss): + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_ss, + self.function_name(), + name_suffix="sampled and suppressed", + partition_count_query=partition_query_ss, + partition_counts=partition_count_ss_results, + partition_count_comment=count_comment, + ) + ) + except sqlalchemy.exc.DatabaseError as exc: + logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) + return [] + return gens + + def _execute_partition_queries( + self, + connection: Connection, + partitions: dict[int, RowPartition], + ) -> bool: + """ + Execute the query in each partition, filling in the covariates. + + :return: True if all the partitions work, False if any of them fail. + """ + found_nonzero = False + for rp in partitions.values(): + covs = connection.execute(text(rp.query)).mappings().fetchall() + if not covs or covs.count == 0 or covs[0]["count"] is None: + rp.covariates = self.EMPTY_RESULT + else: + rp.covariates = covs + found_nonzero = True + return found_nonzero + + +class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): + """ + A generator for numeric and non-numeric columns. + + Any values could be null, the distributions of the nonnull numeric columns + depend on each other and the other non-numeric column values. 
+ """ + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "grouped_multivariate_lognormal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression testing if the value in this column should be used.""" + if is_numeric(column): + # x <> x + 1 ensures that x is not infinity or NaN + return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" + return f"{column.name} IS NOT NULL" + + def query_var(self, column: str) -> str: + """Get the variable or expression we are querying for this column.""" + return f"LN({column})" diff --git a/datafaker/interactive.py b/datafaker/interactive.py deleted file mode 100644 index 17efa199..00000000 --- a/datafaker/interactive.py +++ /dev/null @@ -1,2175 +0,0 @@ -"""Interactive configuration commands.""" -import cmd -import csv -import functools -import re -from abc import ABC, abstractmethod -from collections.abc import Mapping, MutableMapping, Sequence -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from types import TracebackType -from typing import Any, Callable, Iterable, Optional, Type, cast - -import sqlalchemy -from prettytable import PrettyTable -from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table -from typing_extensions import Self - -from datafaker.generators import Generator, PredefinedGenerator, everything_factory -from datafaker.utils import ( - T, - create_db_engine, - fk_refers_to_ignored_table, - get_sync_engine, - logger, - primary_private_fks, - table_is_private, -) - -# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13 -# Windows users can install pyreadline3 to get tab completion working. 
-# See https://github.com/pyreadline3/pyreadline3/issues/37 -try: - import readline - - if not hasattr(readline, "backend"): - setattr(readline, "backend", "readline") -except: - pass - - -def or_default(v: T | None, d: T) -> T: - """Return v if it isn't None, otherwise d.""" - return d if v is None else v - - -class TableType(Enum): - """Types of table to be configured.""" - - GENERATE = "generate" - IGNORE = "ignore" - VOCABULARY = "vocabulary" - PRIVATE = "private" - EMPTY = "empty" - - -TYPE_LETTER = { - TableType.GENERATE: "G", - TableType.IGNORE: "I", - TableType.VOCABULARY: "V", - TableType.PRIVATE: "P", - TableType.EMPTY: "e", -} - -TYPE_PROMPT = { - TableType.GENERATE: "(table: {}) ", - TableType.IGNORE: "(table: {} (ignore)) ", - TableType.VOCABULARY: "(table: {} (vocab)) ", - TableType.PRIVATE: "(table: {} (private)) ", - TableType.EMPTY: "(table: {} (empty))", -} - - -@dataclass -class TableEntry: - """Base class for table entries for interactive commands.""" - - name: str # name of the table - - -class AskSaveCmd(cmd.Cmd): - """Interactive shell for whether to save and quit.""" - - intro = "Do you want to save this configuration?" 
- prompt = "(yes/no/cancel) " - file = None - - def __init__(self) -> None: - """Initialise a save command.""" - super().__init__() - self.result = "" - - def do_yes(self, _arg: str) -> bool: - """Save the new config.yaml.""" - self.result = "yes" - return True - - def do_no(self, _arg: str) -> bool: - """Exit without saving.""" - self.result = "no" - return True - - def do_cancel(self, _arg: str) -> bool: - """Do not exit.""" - self.result = "cancel" - return True - - -def fk_column_name(fk: ForeignKey) -> str: - """Display name for a foreign key.""" - if fk_refers_to_ignored_table(fk): - return f"{fk.target_fullname} (ignored)" - return str(fk.target_fullname) - - -class DbCmd(ABC, cmd.Cmd): - """Base class for interactive configuration commands.""" - - INFO_NO_MORE_TABLES = "There are no more tables" - ERROR_ALREADY_AT_START = "Error: Already at the start" - ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" - ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" - ROW_COUNT_MSG = "Total row count: {}" - - @abstractmethod - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> TableEntry | None: - """ - Make a table entry suitable for this interactive command. - - :param name: The name of the table to make an entry for. - :param table_config: The part of the ``config.yaml`` referring to this table. - :return: The table entry or None if this table should not be interacted with. 
- """ - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ): - """Initialise a DbCmd.""" - super().__init__() - self.config: MutableMapping[str, Any] = config - self.metadata = metadata - self._table_entries: list[TableEntry] = [] - tables_config: MutableMapping = config.get("tables", {}) - if not isinstance(tables_config, MutableMapping): - tables_config = {} - for name in metadata.tables.keys(): - table_config = tables_config.get(name, {}) - if not isinstance(table_config, MutableMapping): - table_config = {} - entry = self.make_table_entry(name, table_config) - if entry is not None: - self._table_entries.append(entry) - self.table_index = 0 - self.engine = create_db_engine(src_dsn, schema_name=src_schema) - - @property - def sync_engine(self) -> Engine: - """Get the synchronous version of the engine.""" - return get_sync_engine(self.engine) - - def __enter__(self) -> Self: - """Enter a ``with`` statement.""" - return self - - def __exit__( - self, - _exc_type: Optional[Type[BaseException]], - _exc_val: Optional[BaseException], - _exc_tb: Optional[TracebackType], - ) -> None: - """Dispose of this object.""" - self.engine.dispose() - - def print(self, text: str, *args: Any, **kwargs: Any) -> None: - """Print text, formatted with positional and keyword arguments.""" - print(text.format(*args, **kwargs)) - - def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: - """ - Print a table. - - :param headings: List of headings for the table. - :param rows: List of rows of values. - """ - output = PrettyTable() - output.field_names = headings - for row in rows: - output.add_row(row) - print(output) - - def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: - """ - Print a table. - - :param columns: Dict of column names to the values in the column. 
- """ - output = PrettyTable() - row_count = max([len(col) for col in columns.values()]) - for field_name, data in columns.items(): - output.add_column(field_name, data + [None] * (row_count - len(data))) - print(output) - - def print_results(self, result: sqlalchemy.CursorResult) -> None: - """Print the rows resulting from a database query.""" - self.print_table(list(result.keys()), [list(row) for row in result.all()]) - - def ask_save(self) -> str: - """ - Ask the user if they want to save. - - :return: ``yes``, ``no`` or ``cancel``. - """ - ask = AskSaveCmd() - ask.cmdloop() - return ask.result - - @abstractmethod - def set_prompt(self) -> None: - """Set the prompt according to the current state.""" - ... - - def _set_table_index(self, index: int) -> bool: - """ - Move to a different table. - - :param index: Index of the table to move to. - :return: True if there is a table with such an index to move to. - """ - if 0 <= index and index < len(self._table_entries): - self.table_index = index - self.set_prompt() - return True - return False - - def next_table(self, report: str = "No more tables") -> bool: - """ - Move to the next table. - - :param report: The text to print if there is no next table. - :return: True if there is another table to move to. 
- """ - if not self._set_table_index(self.table_index + 1): - self.print(report) - return False - return True - - def table_name(self) -> str: - """Get the name of the current table.""" - return str(self._table_entries[self.table_index].name) - - def table_metadata(self) -> Table: - """Get the metadata of the current table.""" - return self.metadata.tables[self.table_name()] - - def _get_column_names(self) -> list[str]: - """Get the names of the current columns.""" - return [col.name for col in self.table_metadata().columns] - - def report_columns(self) -> None: - """Print information about the current columns.""" - self.print_table( - ["name", "type", "primary", "nullable", "foreign key"], - [ - [ - name, - str(col.type), - col.primary_key, - col.nullable, - ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), - ] - for name, col in self.table_metadata().columns.items() - ], - ) - - def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: - """Get the configuration of the named table.""" - ts = self.config.get("tables", None) - if not isinstance(ts, MutableMapping): - return {} - t = ts.get(table_name) - return t if isinstance(t, MutableMapping) else {} - - def set_table_config( - self, table_name: str, config: MutableMapping[str, Any] - ) -> None: - """Set the configuration of the named table.""" - ts = self.config.get("tables", None) - if not isinstance(ts, MutableMapping): - self.config["tables"] = {table_name: config} - return - ts[table_name] = config - - def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: - """Remove all source stats with the given prefix from the configuration.""" - src_stats = self.config.get("src-stats", []) - new_src_stats = [] - for stat in src_stats: - if not stat.get("name", "").startswith(prefix): - new_src_stats.append(stat) - self.config["src-stats"] = new_src_stats - return new_src_stats - - def get_nonnull_columns(self, table_name: str) -> list[str]: - """Get the names of the 
nullable columns in the named table.""" - metadata_table = self.metadata.tables[table_name] - return [ - str(name) - for name, column in metadata_table.columns.items() - if column.nullable - ] - - def find_entry_index_by_table_name(self, table_name: str) -> int | None: - """Get the index of the table entry of the named table.""" - return next( - ( - i - for i, entry in enumerate(self._table_entries) - if entry.name == table_name - ), - None, - ) - - def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None: - """Get the table entry of the named table.""" - for e in self._table_entries: - if e.name == table_name: - return e - return None - - def do_counts(self, _arg: str) -> None: - """Report the column names with the counts of nulls in them.""" - if len(self._table_entries) <= self.table_index: - return - table_name = self.table_name() - nonnull_columns = self.get_nonnull_columns(table_name) - colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" - ) - ).first() - if result is None: - self.print("Could not count rows in table {0}", table_name) - return - row_count = result.row_count - self.print(self.ROW_COUNT_MSG, row_count) - self.print_table( - ["Column", "NULL count"], - [ - [name, row_count - count] - for name, count in result._mapping.items() - if name != "row_count" - ], - ) - - def do_select(self, arg: str) -> None: - """Run a select query over the database and show the first 50 results.""" - max_select_rows = 50 - with self.sync_engine.connect() as connection: - try: - result = connection.execute(sqlalchemy.text("SELECT " + arg)) - except sqlalchemy.exc.DatabaseError as exc: - self.print("Failed to execute: {}", exc) - return - row_count = result.rowcount - self.print(self.ROW_COUNT_MSG, row_count) - if 50 < row_count: - self.print("Showing the first 
{} rows", max_select_rows) - fields = list(result.keys()) - rows = [row._tuple() for row in result.fetchmany(max_select_rows)] - self.print_table(fields, rows) - - def do_peek(self, arg: str) -> None: - """ - View some data from the current table. - - Use 'peek col1 col2 col3' to see a sample of values from - columns col1, col2 and col3 in the current table. - Use 'peek' to see a sample of the current column(s). - Rows that are enitrely null are suppressed. - """ - max_peek_rows = 25 - if len(self._table_entries) <= self.table_index: - return - table_name = self.table_name() - col_names = arg.split() - if not col_names: - col_names = self._get_column_names() - nonnulls = [cn + " IS NOT NULL" for cn in col_names] - with self.sync_engine.connect() as connection: - cols = (",".join(col_names),) - where = ("WHERE" if nonnulls else "",) - nonnull = (" OR ".join(nonnulls),) - query = sqlalchemy.text( - f"SELECT {cols} FROM {table_name} {where} {nonnull}" - f" ORDER BY RANDOM() LIMIT {max_peek_rows}" - ) - try: - result = connection.execute(query) - except Exception as exc: - self.print(f'SQL query "{query}" caused exception {exc}') - return - rows = [row._tuple() for row in result.fetchmany(max_peek_rows)] - self.print_table(list(result.keys()), rows) - - def complete_peek( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Completions for the ``peek`` command.""" - if len(self._table_entries) <= self.table_index: - return [] - return [ - col for col in self.table_metadata().columns.keys() if col.startswith(text) - ] - - -@dataclass -class TableCmdTableEntry(TableEntry): - """Table entry for the table command shell.""" - - old_type: TableType - new_type: TableType - - -class TableCmd(DbCmd): - """Command shell allowing the user to set the type of each table.""" - - intro = ( - "Interactive table configuration (ignore," - " vocabulary, private, generate or empty). Type ? 
for help.\n" - ) - doc_leader = """Use the commands 'ignore', 'vocabulary', -'private', 'empty' or 'generate' to set the table's type. Use 'next' or -'previous' to change table. Use 'tables' and 'columns' for -information about the database. Use 'data', 'peek', 'select' or -'count' to see some data contained in the current table. Use 'quit' -to exit this program.""" - prompt = "(tableconf) " - file = None - WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( - "Vocabulary table {0} references non-vocabulary table {1}" - ) - WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( - "Empty table {1} referenced from non-empty table {0}. {1} will need stories." - ) - WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" - WARNING_TEXT_POTENTIAL_PROBLEMS = ( - "NOTE: The following table types might cause problems later:" - ) - NOTE_TEXT_NO_CHANGES = "You have made no changes." - NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" - - def make_table_entry( - self, table_name: str, table: Mapping - ) -> TableCmdTableEntry | None: - """ - Make a table entry for the named table. - - :param name: The name of the table. - :param table: The part of ``config.yaml`` corresponding to this table. - :return: The newly-constructed table entry. 
- """ - if table.get("ignore", False): - return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE) - if table.get("vocabulary_table", False): - return TableCmdTableEntry( - table_name, TableType.VOCABULARY, TableType.VOCABULARY - ) - if table.get("primary_private", False): - return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE) - if table.get("num_rows_per_pass", 1) == 0: - return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY) - return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE) - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ) -> None: - """Initialise a TableCmd.""" - super().__init__(src_dsn, src_schema, metadata, config) - self.set_prompt() - - @property - def table_entries(self) -> list[TableCmdTableEntry]: - """Get the list of table entries.""" - return cast(list[TableCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: - """Get the table entry of the table with the given name.""" - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(TableCmdTableEntry, entry) - - def set_prompt(self) -> None: - """Set the prompt according to the current table and its type.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) - else: - self.prompt = "(table) " - - def set_type(self, t_type: TableType) -> None: - """Set the type of the current table.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - entry.new_type = t_type - - def _copy_entries(self) -> None: - """Alter the configuration to match the new table entries.""" - for entry in self.table_entries: - if entry.old_type != entry.new_type: - table = self.get_table_config(entry.name) - if 
( - entry.old_type == TableType.EMPTY - and table.get("num_rows_per_pass", 1) == 0 - ): - table["num_rows_per_pass"] = 1 - if entry.new_type == TableType.IGNORE: - table["ignore"] = True - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - elif entry.new_type == TableType.VOCABULARY: - table.pop("ignore", None) - table["vocabulary_table"] = True - table.pop("primary_private", None) - elif entry.new_type == TableType.PRIVATE: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table["primary_private"] = True - elif entry.new_type == TableType.EMPTY: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - table["num_rows_per_pass"] = 0 - else: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - self.set_table_config(entry.name, table) - - def _get_referenced_tables(self, from_table_name: str) -> set[str]: - """Get all the tables referenced by this table's foreign keys.""" - from_meta = self.metadata.tables[from_table_name] - return { - fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys - } - - def _sanity_check_failures(self) -> list[tuple[str, str, str]]: - """Find tables that reference each other that should not given their types.""" - failures = [] - for from_entry in self.table_entries: - from_t = from_entry.new_type - if from_t == TableType.VOCABULARY: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self._find_entry_by_table_name(ref) - if ( - to_entry is not None - and to_entry.new_type != TableType.VOCABULARY - ): - failures.append( - ( - self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - from_entry.name, - to_entry.name, - ) - ) - return failures - - def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: - """Find tables that reference each other that might cause problems given their types.""" - warnings = [] - for from_entry in 
self.table_entries: - from_t = from_entry.new_type - if from_t in {TableType.GENERATE, TableType.PRIVATE}: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self._find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type in { - TableType.EMPTY, - TableType.IGNORE, - }: - warnings.append( - ( - self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - from_entry.name, - to_entry.name, - ) - ) - return warnings - - def do_quit(self, _arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - self.print( - self.NOTE_TEXT_CHANGING, - entry.name, - entry.old_type.value, - entry.new_type.value, - ) - if count == 0: - self.print(self.NOTE_TEXT_NO_CHANGES) - failures = self._sanity_check_failures() - if failures: - self.print(self.WARNING_TEXT_PROBLEMS_EXIST) - for text, from_t, to_t in failures: - self.print(text, from_t, to_t) - warnings = self._sanity_check_warnings() - if warnings: - self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) - for text, from_t, to_t in warnings: - self.print(text, from_t, to_t) - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables with their types.""" - for entry in self.table_entries: - old = entry.old_type - new = entry.new_type - becomes = " " if old == new else "->" + TYPE_LETTER[new] - self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) - - def do_next(self, arg: str) -> None: - """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" - if arg: - # Find the index of the table called _arg, if any - index = self.find_entry_index_by_table_name(arg) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self._set_table_index(index) - return - 
self.next_table(self.INFO_NO_MORE_TABLES) - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Get the completions for tables and columns.""" - return [ - entry.name for entry in self.table_entries if entry.name.startswith(text) - ] - - def do_previous(self, _arg: str) -> None: - """Go to the previous table.""" - if not self._set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - - def do_ignore(self, _arg: str) -> None: - """Set the current table as ignored, and go to the next table.""" - self.set_type(TableType.IGNORE) - self.print("Table {} set as ignored", self.table_name()) - self.next_table() - - def do_vocabulary(self, _arg: str) -> None: - """Set the current table as a vocabulary table, and go to the next table.""" - self.set_type(TableType.VOCABULARY) - self.print("Table {} set to be a vocabulary table", self.table_name()) - self.next_table() - - def do_private(self, _arg: str) -> None: - """Set the current table as a primary private table (such as the table of patients).""" - self.set_type(TableType.PRIVATE) - self.print("Table {} set to be a primary private table", self.table_name()) - self.next_table() - - def do_generate(self, _arg: str) -> None: - """Set the current table as to be generated, and go to the next table.""" - self.set_type(TableType.GENERATE) - self.print("Table {} generate", self.table_name()) - self.next_table() - - def do_empty(self, _arg: str) -> None: - """Set the current table as empty; no generators will be run for it.""" - self.set_type(TableType.EMPTY) - self.print("Table {} empty", self.table_name()) - self.next_table() - - def do_columns(self, _arg: str) -> None: - """Report the column names and metadata.""" - self.report_columns() - - def do_data(self, arg: str) -> None: - """ - Report some data. 
- - 'data' = report a random ten lines, - 'data 20' = report a random 20 lines, - 'data 20 ColumnName' = report a random twenty entries from ColumnName, - 'data 20 ColumnName 30' = report a random twenty entries from ColumnName of length at least 30, - """ - args = arg.split() - column = None - number = None - arg_index = 0 - min_length = 0 - table_metadata = self.table_metadata() - if arg_index < len(args) and args[arg_index].isdigit(): - number = int(args[arg_index]) - arg_index += 1 - if arg_index < len(args) and args[arg_index] in table_metadata.columns: - column = args[arg_index] - arg_index += 1 - if arg_index < len(args) and args[arg_index].isdigit(): - min_length = int(args[arg_index]) - arg_index += 1 - if arg_index != len(args): - self.print( - """Did not understand these arguments -The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text. -Type 'columns' to find out valid column names for this table. -Type 'help data' for examples.""" - ) - return - if column is None: - if number is None: - number = 10 - self.print_row_data(number) - else: - if number is None: - number = 48 - self.print_column_data(column, number, min_length) - - def complete_data( - self, text: str, line: str, begidx: int, _endidx: int - ) -> list[str]: - """Get completions for arguments to ``data``.""" - previous_parts = line[: begidx - 1].split() - if len(previous_parts) != 2: - return [] - table_metadata = self.table_metadata() - return [k for k in table_metadata.columns.keys() if k.startswith(text)] - - def print_column_data(self, column: str, count: int, min_length: int) -> None: - """ - Print a sample of data from a certain column of the current table. - - :param column: The name of the column to report on. - :param count: The number of rows to sample. - :param min_length: The minimum length of text to choose from (0 for any text). 
- """ - where = f"WHERE {column} IS NOT NULL" - if 0 < min_length: - where = "WHERE LENGTH({column}) >= {len}".format( - column=column, - len=min_length, - ) - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - column=column, - count=count, - where=where, - ) - ) - ) - self.columnize([str(x[0]) for x in result.all()]) - - def print_row_data(self, count: int) -> None: - """ - Print a sample or rows from the current table. - - :param count: The number of rows to report. - """ - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - count=count, - ) - ) - ) - if result is None: - self.print("No rows in this table!") - return - self.print_results(result) - - -def update_config_tables( - src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping -) -> Mapping[str, Any]: - """Ask the user to specify what should happen to each table.""" - with TableCmd(src_dsn, src_schema, metadata, config) as tc: - tc.cmdloop() - return tc.config - - -@dataclass -class MissingnessType: - """The functions required for applying missingness.""" - - SAMPLED = "column_presence.sampled" - SAMPLED_QUERY = ( - "SELECT COUNT(*) AS row_count, {result_names} FROM " - "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})" - " AS __t GROUP BY {result_names}" - ) - name: str - query: str - comment: str - columns: list[str] - - @classmethod - def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str: - """ - Construct a query to make a sampling of the named rows of the table. - - :param table: The name of the table to sample. - :param count: The number of samples to get. - :param column_names: The columns to fetch. 
- :return: The SQL query to do the sampling. - """ - result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) - column_is_nulls = ", ".join( - ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] - ) - return cls.SAMPLED_QUERY.format( - result_names=result_names, - column_is_nulls=column_is_nulls, - table=table, - count=count, - ) - - -@dataclass -class MissingnessCmdTableEntry(TableEntry): - """Table entry for the missingness command shell.""" - - old_type: MissingnessType - new_type: MissingnessType | None - - -class MissingnessCmd(DbCmd): - """ - Interactive shell for the user to set missingness. - - Can only be used for Missingness Completely At Random. - """ - - intro = "Interactive missingness configuration. Type ? for help.\n" - doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for -the current table. Use commands 'next' and 'previous' to change the -current table. Use 'tables' to list the tables and 'count' to show -how many NULLs exist in each column. Use 'peek' or 'select' to see -data from the database. Use 'quit' to exit this tool.""" - prompt = "(missingness) " - file = None - PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]') - - def find_missingness_query( - self, missingness_generator: Mapping - ) -> tuple[str, str] | None: - """Find query and comment from src-stats for the passed missingness generator.""" - kwargs = missingness_generator.get("kwargs", {}) - patterns = kwargs.get("patterns", "") - pattern_match = self.PATTERN_RE.match(patterns) - if pattern_match: - key = pattern_match.group(1) - for src_stat in self.config["src-stats"]: - if src_stat.get("name") == key: - query = src_stat.get("query", None) - if type(query) is not str: - return None - return (query, src_stat.get("comment", "")) - return None - - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> MissingnessCmdTableEntry | None: - """ - Make a table entry for a particular table. 
- - :param name: The name of the table to make an entry for. - :param table: The part of ``config.yaml`` relating to this table. - :return: The newly-constructed table entry. - """ - if table_config.get("ignore", False): - return None - if table_config.get("vocabulary_table", False): - return None - if table_config.get("num_rows_per_pass", 1) == 0: - return None - mgs = table_config.get("missingness_generators", []) - old = None - nonnull_columns = self.get_nonnull_columns(table_name) - if not nonnull_columns: - return None - if not mgs: - old = MissingnessType( - name="none", - query="", - comment="", - columns=[], - ) - elif len(mgs) == 1: - mg = mgs[0] - mg_name = mg.get("name", None) - if isinstance(mg_name, str): - query_comment = self.find_missingness_query(mg) - if query_comment is not None: - (query, comment) = query_comment - old = MissingnessType( - name=mg_name, - query=query, - comment=comment, - columns=mg.get("columns_assigned", []), - ) - if old is None: - return None - return MissingnessCmdTableEntry( - name=table_name, - old_type=old, - new_type=old, - ) - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping, - ): - """ - Initialise a MissingnessCmd. - - :param src_dsn: connection string for the source database. - :param src_schema: schema name for the source database. - :param metadata: SQLAlchemy metadata for the source database. - :param config: Configuration from the ``config.yaml`` file. 
- """ - super().__init__(src_dsn, src_schema, metadata, config) - self.set_prompt() - - @property - def table_entries(self) -> list[MissingnessCmdTableEntry]: - """Get the table entries list.""" - return cast(list[MissingnessCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name( - self, table_name: str - ) -> MissingnessCmdTableEntry | None: - """Find the table entry given the table name.""" - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(MissingnessCmdTableEntry, entry) - - def set_prompt(self) -> None: - """Set the prompt according to the current table and missingness.""" - if self.table_index < len(self.table_entries): - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - nt = entry.new_type - if nt is None: - self.prompt = f"(missingness for {entry.name}) " - else: - self.prompt = f"(missingness for {entry.name}: {nt.name}) " - else: - self.prompt = "(missingness) " - - def set_type(self, t_type: MissingnessType) -> None: - """Set the missingness of the current table.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - entry.new_type = t_type - - def _copy_entries(self) -> None: - """Set the new missingness into the configuration.""" - src_stats = self._remove_prefix_src_stats("missing_auto__") - for entry in self.table_entries: - table = self.get_table_config(entry.name) - if entry.new_type is None or entry.new_type.name == "none": - table.pop("missingness_generators", None) - else: - src_stat_key = f"missing_auto__{entry.name}__0" - table["missingness_generators"] = [ - { - "name": entry.new_type.name, - "kwargs": { - "patterns": f'SRC_STATS["{src_stat_key}"]["results"]' - }, - "columns": entry.new_type.columns, - } - ] - src_stats.append( - { - "name": src_stat_key, - "query": entry.new_type.query, - "comments": [] - if entry.new_type.comment is None - else [entry.new_type.comment], - } - ) - 
self.set_table_config(entry.name, table) - - def do_quit(self, _arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - if entry.old_type is None: - self.print( - "Putting generator {0} on table {1}", - entry.name, - entry.new_type.name, - ) - elif entry.new_type is None: - self.print( - "Deleting generator {1} from table {0}", - entry.name, - entry.old_type.name, - ) - else: - self.print( - "Changing {0} from {1} to {2}", - entry.name, - entry.old_type.name, - entry.new_type.name, - ) - if count == 0: - self.print("You have made no changes.") - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables with their types.""" - for entry in self.table_entries: - old = "-" if entry.old_type is None else entry.old_type.name - new = "-" if entry.new_type is None else entry.new_type.name - desc = new if old == new else f"{old}->{new}" - self.print("{0} {1}", entry.name, desc) - - def do_next(self, arg: str) -> None: - """ - Go to the next table, or a specified table. 
- - 'next' = go to the next table, 'next tablename' = go to table 'tablename' - """ - if arg: - # Find the index of the table called _arg, if any - index = next( - (i for i, entry in enumerate(self.table_entries) if entry.name == arg), - None, - ) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self._set_table_index(index) - return - self.next_table(self.INFO_NO_MORE_TABLES) - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Get completions for tables and columns.""" - return [ - entry.name for entry in self.table_entries if entry.name.startswith(text) - ] - - def do_previous(self, _arg: str) -> None: - """Go to the previous table.""" - if not self._set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - - def _set_type(self, name: str, query: str, comment: str) -> None: - """Set the current table entry's query.""" - if len(self.table_entries) <= self.table_index: - return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = MissingnessType( - name=name, - query=query, - comment=comment, - columns=self.get_nonnull_columns(entry.name), - ) - - def _set_none(self) -> None: - """Set the current table to have no missingness applied.""" - if len(self.table_entries) <= self.table_index: - return - self.table_entries[self.table_index].new_type = None - - def do_sampled(self, arg: str) -> None: - """ - Set the current table missingness as 'sampled', and go to the next table. - - 'sampled 3000' means sample 3000 rows at random and choose the - missingness to be the same as one of those 3000 at random. - 'sampled' means the same, but with a default number of rows sampled (1000). - """ - if len(self.table_entries) <= self.table_index: - self.print("Error! 
not on a table") - return - entry = self.table_entries[self.table_index] - if arg == "": - count = 1000 - elif arg.isdecimal(): - count = int(arg) - else: - self.print( - "Error: sampled can be used alone or with an integer argument. {0} is not permitted", - arg, - ) - return - self._set_type( - MissingnessType.SAMPLED, - MissingnessType.sampled_query( - entry.name, - count, - self.get_nonnull_columns(entry.name), - ), - ( - "The missingness patterns and how often they appear in a" - f" sample of {count} from table {entry.name}" - ), - ) - self.print("Table {} set to sampled missingness", self.table_name()) - self.next_table() - - def do_none(self, _arg: str) -> None: - """Set the current table to have no missingness, and go to the next table.""" - self._set_none() - self.print("Table {} set to have no missingness", self.table_name()) - self.next_table() - - -def update_missingness( - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], -) -> Mapping[str, Any]: - """ - Ask the user to update the missingness information in ``config.yaml``. - - :param src_dsn: The connection string for the source database. - :param src_schema: The name of the source database schema (or None - for the default). - :param metadata: The SQLAlchemy metadata object from ``orm.yaml``. - :param config: The starting configuration, - :return: The updated configuration. - """ - with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: - mc.cmdloop() - return mc.config - - -@dataclass -class GeneratorInfo: - """A generator and the columns it assigns to.""" - - columns: list[str] - gen: Generator | None - - -@dataclass -class GeneratorCmdTableEntry(TableEntry): - """ - List of generators set for a table. - - Includes the original setting and the currently configured - generators. 
- """ - - old_generators: list[GeneratorInfo] - new_generators: list[GeneratorInfo] - - -class GeneratorCmd(DbCmd): - """Interactive command shell for setting generators.""" - - intro = "Interactive generator configuration. Type ? for help.\n" - doc_leader = """Use command 'propose' for a list of generators applicable to the -current column, then command 'compare' to see how these perform -against the source data, then command 'set' to choose your favourite. -Use 'unset' to remove the column's generator. Use commands 'next' and -'previous' to change which column we are examining. Use 'info' -for useful information about the current column. Use 'tables' and -'list' to see available tables and columns. Use 'columns' to see -information about the columns in the current table. Use 'peek', -'count' or 'select' to fetch data from the source database. Use -'quit' to exit this program.""" - prompt = "(generatorconf) " - file = None - - PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..." - PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column." - PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..." - PRIMARY_PRIVATE_TEXT = "Primary Private" - SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}" - NOT_PRIVATE_TEXT = "Not private" - ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}" - ERROR_NO_SUCH_COLUMN = "No such column {0} in this table" - ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged" - ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged" - ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" - PROPOSE_NOTHING = "No proposed generators, sorry." - - SRC_STAT_RE = re.compile( - r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' - ) - - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> GeneratorCmdTableEntry | None: - """ - Make a table entry. - - :param table_name: The name of the table. 
- :param table: The portion of the ``config.yaml`` file describing this table. - :return: The newly constructed table entry, or None if this table is to be ignored. - """ - if table_config.get("ignore", False): - return None - if table_config.get("vocabulary_table", False): - return None - if table_config.get("num_rows_per_pass", 1) == 0: - return None - metadata_table = self.metadata.tables[table_name] - columns = [str(colname) for colname in metadata_table.columns.keys()] - column_set = frozenset(columns) - columns_assigned_so_far: set[str] = set() - - new_generator_infos: list[GeneratorInfo] = [] - old_generator_infos: list[GeneratorInfo] = [] - for rg in table_config.get("row_generators", []): - gen_name = rg.get("name", None) - if gen_name: - ca = rg.get("columns_assigned", []) - collist: list[str] = ( - [ca] if isinstance(ca, str) else [str(c) for c in ca] - ) - colset: set[str] = set(collist) - for unknown in colset - column_set: - logger.warning( - "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, - gen_name, - unknown, - ) - for mult in columns_assigned_so_far & colset: - logger.warning( - "table '%s' has column '%s' assigned to multiple times", - table_name, - mult, - ) - actual_collist = [c for c in collist if c in columns] - if actual_collist: - gen = PredefinedGenerator(table_name, rg, self.config) - new_generator_infos.append( - GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) - ) - old_generator_infos.append( - GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) - ) - columns_assigned_so_far |= colset - for colname in columns: - if colname not in columns_assigned_so_far: - new_generator_infos.append( - GeneratorInfo( - columns=[colname], - gen=None, - ) - ) - if len(new_generator_infos) == 0: - return None - return GeneratorCmdTableEntry( - name=table_name, - old_generators=old_generator_infos, - new_generators=new_generator_infos, - ) - - def __init__( - self, - src_dsn: str, - 
src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ) -> None: - """ - Initialise a ``GeneratorCmd``. - - :param src_dsn: connection address for source database - :param src_schema: database schema name - :param metadata: SQLAlchemy metadata for the source database - :param config: Configuration loaded from ``config.yaml`` - """ - super().__init__(src_dsn, src_schema, metadata, config) - self.generators: list[Generator] | None = None - self.generator_index = 0 - self.generators_valid_columns: Optional[tuple[int, list[str]]] = None - self.set_prompt() - - @property - def table_entries(self) -> list[GeneratorCmdTableEntry]: - """Get the talbe entries, cast to ``GeneratorCmdTableEntry``.""" - return cast(list[GeneratorCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name( - self, table_name: str - ) -> GeneratorCmdTableEntry | None: - """ - Find the table entry by name. - - :param table_name: The name of the table to find. - :return: The table entry, or None if no such table name exists. - """ - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(GeneratorCmdTableEntry, entry) - - def _set_table_index(self, index: int) -> bool: - """ - Move to a new table. - - :param index: table index to move to. - """ - ret = super()._set_table_index(index) - if ret: - self.generator_index = 0 - self.set_prompt() - return ret - - def _previous_table(self) -> bool: - """ - Move to the table before the current one. - - :return: True if there is a previous table to go to. - """ - ret = self._set_table_index(self.table_index - 1) - if ret: - table = self.get_table() - if table is None: - self.print( - "Internal error! 
table {0} does not have any generators!", - self.table_index, - ) - return False - self.generator_index = len(table.new_generators) - 1 - else: - self.print(self.ERROR_ALREADY_AT_START) - return ret - - def get_table(self) -> GeneratorCmdTableEntry | None: - """Get the current table entry.""" - if self.table_index < len(self.table_entries): - return self.table_entries[self.table_index] - return None - - def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: - """Get a pair; the table name then the generator information.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - if self.generator_index < len(entry.new_generators): - return (entry.name, entry.new_generators[self.generator_index]) - return (entry.name, None) - return (None, None) - - def _get_column_names(self) -> list[str]: - """Get the (unqualified) names for all the current columns.""" - (_, generator_info) = self._get_table_and_generator() - return generator_info.columns if generator_info else [] - - def _column_metadata(self) -> list[Column]: - """Get the metadata for all the current columns.""" - table = self.table_metadata() - if table is None: - return [] - return [table.columns[name] for name in self._get_column_names()] - - def set_prompt(self) -> None: - """Set the prompt according to the current table, column and generator.""" - (table_name, gen_info) = self._get_table_and_generator() - if table_name is None: - self.prompt = "(generators) " - return - if gen_info is None: - self.prompt = f"({table_name}) " - return - table = self.table_metadata() - columns = [ - c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns - ] - gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" - self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - - def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]: - """ - Remove all automatic source stats. 
- - We assume every source stats query whose name begins with ``auto__` - :return: The new ``src_stats`` configuration. - """ - return self._remove_prefix_src_stats("auto__") - - def _copy_entries(self) -> None: - """Set generator and query information in the configuration.""" - src_stats = self._remove_auto_src_stats() - for entry in self.table_entries: - rgs = [] - new_gens: list[Generator] = [] - for generator in entry.new_generators: - if generator.gen is not None: - new_gens.append(generator.gen) - cqs = generator.gen.custom_queries() - for cq_key, cq in cqs.items(): - src_stats.append( - { - "name": cq_key, - "query": cq["query"], - "comments": [cq["comment"]] - if "comment" in cq and cq["comment"] - else [], - } - ) - rg: dict[str, Any] = { - "name": generator.gen.function_name(), - "columns_assigned": generator.columns, - } - kwn = generator.gen.nominal_kwargs() - if kwn: - rg["kwargs"] = kwn - rgs.append(rg) - aq = self._get_aggregate_query(new_gens, entry.name) - if aq: - src_stats.append( - { - "name": f"auto__{entry.name}", - "query": aq, - "comments": [ - q["comment"] - for gen in new_gens - for q in gen.select_aggregate_clauses().values() - if "comment" in q and q["comment"] is not None - ], - } - ) - table_config = self.get_table_config(entry.name) - if rgs: - table_config["row_generators"] = rgs - elif "row_generators" in table_config: - del table_config["row_generators"] - self.set_table_config(entry.name, table_config) - self.config["src-stats"] = src_stats - - def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns: Iterable[str] - ) -> Generator | None: - """Find any generator that previously assigned to these exact same columns.""" - fc = frozenset(columns) - for gen in entry.old_generators: - if frozenset(gen.columns) == fc: - return gen.gen - return None - - def do_quit(self, arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - header_shown 
= False - g_entry = cast(GeneratorCmdTableEntry, entry) - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - new_gen = None if gen is None else gen.gen - if old_gen != new_gen: - if not header_shown: - header_shown = True - self.print("Table {0}:", entry.name) - count += 1 - self.print( - "...changing {0} from {1} to {2}", - ", ".join(gen.columns), - old_gen.name() if old_gen else "nothing", - gen.gen.name() if gen.gen else "nothing", - ) - if count == 0: - self.print("You have made no changes.") - if arg in {"yes", "no"}: - reply = arg - else: - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables.""" - for t_entry in self.table_entries: - entry = cast(GeneratorCmdTableEntry, t_entry) - gen_count = len(entry.new_generators) - how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" - self.print("{0} ({1})", entry.name, how_many) - - def do_list(self, _arg: str) -> None: - """List the generators in the current table.""" - if len(self.table_entries) <= self.table_index: - self.print("Error: no table {0}", self.table_index) - return - g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) - table = self.table_metadata() - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - old = "" if old_gen is None else old_gen.name() - if old_gen == gen.gen: - becomes = "" - if old == "": - old = "(not set)" - elif gen.gen is None: - becomes = "(delete)" - else: - becomes = f"->{gen.gen.name()}" - primary = "" - if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: - primary = "[primary-key]" - self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) - - def do_columns(self, _arg: str) -> None: - """Report the column names and metadata.""" - self.report_columns() - - def do_info(self, 
_arg: str) -> None: - """Show information about the current column.""" - for cm in self._column_metadata(): - self.print( - "Column {0} in table {1} has type {2} ({3}).", - cm.name, - cm.table.name, - str(cm.type), - "nullable" if cm.nullable else "not nullable", - ) - if cm.primary_key: - self.print( - "It is a primary key, which usually does not need a generator (it will auto-increment)" - ) - if cm.foreign_keys: - fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] - self.print( - "It is a foreign key referencing column {0}", ", ".join(fk_names) - ) - if len(fk_names) == 1 and not cm.primary_key: - self.print( - "You do not need a generator if you just want a uniform choice over the referenced table's rows" - ) - - def _get_table_index(self, table_name: str) -> int | None: - """Get the index of the named table in the table entries list.""" - for n, entry in enumerate(self.table_entries): - if entry.name == table_name: - return n - return None - - def _get_generator_index(self, table_index: int, column_name: str) -> int | None: - """ - Get the index number of a column within the list of generators in this table. - - :param table_index: The index of the table in which to search. - :param column_name: The name of the column to search for. - :return: The index in the ``new_generators`` attribute of the table entry - containing the specified column, or None if this does not exist. - """ - entry = self.table_entries[table_index] - for n, gen in enumerate(entry.new_generators): - if column_name in gen.columns: - return n - return None - - def go_to(self, target: str) -> bool: - """ - Go to a particular column. - - :return: True on success. 
- """ - parts = target.split(".", 1) - table_index = self._get_table_index(parts[0]) - if table_index is None: - if len(parts) == 1: - gen_index = self._get_generator_index(self.table_index, parts[0]) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0]) - return False - gen_index = None - if 1 < len(parts) and parts[1]: - gen_index = self._get_generator_index(table_index, parts[1]) - if gen_index is None: - self.print("we cannot set the generator for column {0}", parts[1]) - return False - self._set_table_index(table_index) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - - def do_next(self, arg: str) -> None: - """ - Go to the next generator. or a specified generator. - - Go to a named table: 'next tablename', - go to a column: 'next tablename.columnname', - or go to a column within this table: 'next columnname'. - """ - if arg: - self.go_to(arg) - else: - self._go_next() - - def do_n(self, arg: str) -> None: - """Go to the next generator, or a specified generator.""" - self.do_next(arg) - - def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: - """Complete the ``n`` command's arguments.""" - return self.complete_next(text, line, begidx, endidx) - - def _go_next(self) -> None: - """Go to the next column.""" - table = self.get_table() - if table is None: - self.print("No more tables") - return - next_gi = self.generator_index + 1 - if next_gi == len(table.new_generators): - self.next_table(self.INFO_NO_MORE_TABLES) - return - self.generator_index = next_gi - self.set_prompt() - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Completions for the arguments of the ``next`` command.""" - parts = text.split(".", 1) - first_part = parts[0] - if 1 < len(parts): - column_name = parts[1] - table_index = self._get_table_index(first_part) - if 
table_index is None: - return [] - table_entry = self.table_entries[table_index] - return [ - f"{first_part}.{column}" - for gen in table_entry.new_generators - for column in gen.columns - if column.startswith(column_name) - ] - table_names = [ - entry.name - for entry in self.table_entries - if entry.name.startswith(first_part) - ] - if first_part in table_names: - table_names.append(f"{first_part}.") - current_table = self.get_table() - if current_table: - column_names = [ - col - for gen in current_table.new_generators - for col in gen.columns - if col.startswith(first_part) - ] - else: - column_names = [] - return table_names + column_names - - def do_previous(self, _arg: str) -> None: - """Go to the previous generator.""" - if self.generator_index == 0: - self._previous_table() - else: - self.generator_index -= 1 - self.set_prompt() - - def do_b(self, arg: str) -> None: - """Synonym for previous.""" - self.do_previous(arg) - - def _generators_valid(self) -> bool: - """Test if ``self.generators`` is still correct for the current columns.""" - return self.generators_valid_columns == ( - self.table_index, - self._get_column_names(), - ) - - def _get_generator_proposals(self) -> list[Generator]: - """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" - if not self._generators_valid(): - self.generators = None - if self.generators is None: - columns = self._column_metadata() - gens = everything_factory().get_generators(columns, self.sync_engine) - sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) - self.generators = sorted_gens - self.generators_valid_columns = ( - self.table_index, - self._get_column_names().copy(), - ) - return self.generators - - def _print_privacy(self) -> None: - """Print the privacy status of the current table.""" - table = self.table_metadata() - if table is None: - return - if table_is_private(self.config, table.name): - self.print(self.PRIMARY_PRIVATE_TEXT) - return - pfks = 
primary_private_fks(self.config, table) - if not pfks: - self.print(self.NOT_PRIVATE_TEXT) - return - self.print(self.SECONDARY_PRIVATE_TEXT, pfks) - - def do_compare(self, arg: str) -> None: - """ - Compare the real data with some generators. - - 'compare': just look at some source data from this column. - 'compare 5 6 10': compare a sample of the source data with a sample - from generators 5, 6 and 10. You can find out which numbers - correspond to which generators using the 'propose' command. - """ - self._print_privacy() - args = arg.split() - limit = 20 - comparison = { - "source": [ - x[0] if len(x) == 1 else ", ".join(x) - for x in self._get_column_data(limit, to_str=str) - ] - } - gens: list[Generator] = self._get_generator_proposals() - table_name = self.table_name() - for argument in args: - if argument.isdigit(): - n = int(argument) - if 0 < n <= len(gens): - gen = gens[n - 1] - comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit) - self._print_values_queried(table_name, n, gen) - self.print_table_by_columns(comparison) - - def do_c(self, arg: str) -> None: - """Synonym for compare.""" - self.do_compare(arg) - - def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: - """ - Print the values queried from the database for this generator. - - :param table_name: The name of the table the generator applies to. - :param n: A number to print at the start of the output. - :param gen: The generator to report. - """ - if not gen.select_aggregate_clauses() and not gen.custom_queries(): - self.print( - "{0}. {1} requires no data from the source database.", - n, - gen.name(), - ) - else: - self.print( - "{0}. {1} requires the following data from the source database:", - n, - gen.name(), - ) - self._print_select_aggregate_query(table_name, gen) - self._print_custom_queries(gen) - - def _print_custom_queries(self, gen: Generator) -> None: - """ - Print all the custom queries and all the values they get in this case. 
- - :param gen: The generator to print the custom queries for. - """ - cqs = gen.custom_queries() - if not cqs: - return - cq_key2args: dict[str, Any] = {} - nominal = gen.nominal_kwargs() - actual = gen.actual_kwargs() - self._get_custom_queries_from( - cq_key2args, - nominal, - actual, - ) - for cq_key, cq in cqs.items(): - self.print( - "{0}; providing the following values: {1}", - cq["query"], - cq_key2args[cq_key], - ) - - def _get_custom_queries_from( - self, out: dict[str, Any], nominal: Any, actual: Any - ) -> None: - if isinstance(nominal, str): - src_stat_groups = self.SRC_STAT_RE.search(nominal) - # Do we have a SRC_STAT reference? - if src_stat_groups: - # Get its name - cq_key = src_stat_groups.group(1) - # Are we pulling a specific part of this result? - sub = src_stat_groups.group(3) - if sub: - actual = {sub: actual} - else: - out[cq_key] = actual - elif isinstance(nominal, Sequence) and isinstance(actual, Sequence): - for i in range(min(len(nominal), len(actual))): - self._get_custom_queries_from(out, nominal[i], actual[i]) - elif isinstance(nominal, Mapping) and isinstance(actual, Mapping): - for k, v in nominal.items(): - if k in actual: - self._get_custom_queries_from(out, v, actual[k]) - - def _get_aggregate_query( - self, gens: list[Generator], table_name: str - ) -> str | None: - clauses = [ - f'{q["clause"]} AS {n}' - for gen in gens - for n, q in or_default(gen.select_aggregate_clauses(), {}).items() - ] - if not clauses: - return None - return f"SELECT {', '.join(clauses)} FROM {table_name}" - - def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: - """ - Print the select aggregate query and all the values it gets in this case. - - This is not the entire query that will be executed, but only the part of it - that is required by a certain generator. - :param table_name: The table name. - :param gen: The generator to limit the aggregate query to. 
- """ - sacs = gen.select_aggregate_clauses() - if not sacs: - return - kwa = gen.actual_kwargs() - vals = [] - src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} - for n in sacs.keys(): - src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' - if src_stat in src_stat2kwarg: - ak = src_stat2kwarg[src_stat] - if ak in kwa: - vals.append(kwa[ak]) - else: - logger.warning( - "actual_kwargs for %s does not report %s", gen.name(), ak - ) - else: - logger.warning( - 'nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', - gen.name(), - table_name, - n, - ) - select_q = self._get_aggregate_query([gen], table_name) - self.print("{0}; providing the following values: {1}", select_q, vals) - - def _get_column_data( - self, count: int, to_str: Callable[[Any], str] = repr - ) -> list[list[str]]: - columns = self._get_column_names() - columns_string = ", ".join(columns) - pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" - ) - ) - return [[to_str(x) for x in xs] for xs in result.all()] - - def do_propose(self, _arg: str) -> None: - """ - Display a list of possible generators for this column. - - They will be listed in order of fit, the most likely matches first. - The results can be compared (against a sample of the real data in - the column and against each other) with the 'compare' command. 
- """ - limit = 5 - gens = self._get_generator_proposals() - sample = self._get_column_data(limit) - if sample: - rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] - self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) - else: - self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) - if not gens: - self.print(self.PROPOSE_NOTHING) - for index, gen in enumerate(gens): - fit = gen.fit(-1) - if fit == -1: - fit_s = "(no fit)" - elif fit < 100: - fit_s = f"(fit: {fit:.3g})" - else: - fit_s = f"(fit: {fit:.0f})" - self.print( - self.PROPOSE_GENERATOR_SAMPLE_TEXT, - index=index + 1, - name=gen.name(), - fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))), - ) - - def do_p(self, arg: str) -> None: - """Synonym for propose.""" - self.do_propose(arg) - - def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: - """Find a generator by name from the list of proposals.""" - for gen in self._get_generator_proposals(): - if gen.name() == gen_name: - return gen - return None - - def do_set(self, arg: str) -> None: - """Set one of the proposals as a generator.""" - if arg.isdigit() and not self._generators_valid(): - self.print("Please run 'propose' before 'set '") - return - gens = self._get_generator_proposals() - new_gen: Generator | None - if arg.isdigit(): - index = int(arg) - if index < 1: - self.print("set's integer argument must be at least 1") - return - if len(gens) < index: - self.print( - "There are currently only {0} generators proposed, please select one of them.", - len(gens), - ) - return - new_gen = gens[index - 1] - else: - new_gen = self.get_proposed_generator_by_name(arg) - if new_gen is None: - self.print("'{0}' is not an appropriate generator for this column", arg) - return - self.set_generator(new_gen) - self._go_next() - - def set_generator(self, gen: Generator | None) -> None: - """Set the current column's generator.""" - (table, gen_info) = self._get_table_and_generator() - if table is None: - self.print("Error: 
no table") - return - if gen_info is None: - self.print("Error: no column") - return - gen_info.gen = gen - - def do_s(self, arg: str) -> None: - """Synonym for set.""" - self.do_set(arg) - - def do_unset(self, _arg: str) -> None: - """Remove any generator set for this column.""" - self.set_generator(None) - self._go_next() - - def do_merge(self, arg: str) -> None: - """ - Add this column(s) to the specified column(s). - - After this, one generator will cover them all. - """ - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return - cols_available = functools.reduce( - lambda x, y: x | y, - [frozenset(gen.columns) for gen in table_entry.new_generators], - ) - cols_to_merge = frozenset(cols) - unknown_cols = cols_to_merge - cols_available - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - stated_current_columns = cols_to_merge & current_columns - if stated_current_columns: - for c in stated_current_columns: - self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) - return - # Remove cols_to_merge from each generator - new_new_generators: list[GeneratorInfo] = [] - for gen in table_entry.new_generators: - if gen is gen_info: - # Add columns to this generator - self.generator_index = len(new_new_generators) - new_new_generators.append( - GeneratorInfo( - columns=gen.columns + cols, - gen=None, - ) - ) - else: - # Remove columns if applicable - new_columns = [c for c in gen.columns if c not in cols_to_merge] - is_changed = len(new_columns) != len(gen.columns) - if new_columns: - # We have not removed this generator completely - new_new_generators.append( - GeneratorInfo( - columns=new_columns, - gen=None if is_changed else gen.gen, - ) - ) - 
table_entry.new_generators = new_new_generators - self.set_prompt() - - def complete_merge( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Complete column names.""" - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - return [] - return [ - column - for i, gen in enumerate(table_entry.new_generators) - if i != self.generator_index - for column in gen.columns - if column.startswith(last_arg) - ] - - def do_unmerge(self, arg: str) -> None: - """Remove this column(s) from this generator, make them a separate generator.""" - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - cols_to_unmerge = frozenset(cols) - unknown_cols = cols_to_unmerge - current_columns - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return - stated_unmerged_columns = cols_to_unmerge - current_columns - if stated_unmerged_columns: - for c in stated_unmerged_columns: - self.print(self.ERROR_COLUMN_ALREADY_UNMERGED, c) - return - if cols_to_unmerge == current_columns: - self.print(self.ERROR_CANNOT_UNMERGE_ALL) - return - # Remove unmerged columns - for um in cols_to_unmerge: - gen_info.columns.remove(um) - # The existing generator will not work - gen_info.gen = None - # And put them into a new (empty) generator - table_entry.new_generators.insert( - self.generator_index + 1, - GeneratorInfo( - columns=cols, - gen=None, - ), - ) - self.set_prompt() - - def complete_unmerge( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Complete column names to unmerge.""" - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry | None = 
self.get_table() - if table_entry is None: - return [] - return [ - column - for column in table_entry.new_generators[self.generator_index].columns - if column.startswith(last_arg) - ] - - -def update_config_generators( - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - spec_path: Path | None, -) -> Mapping[str, Any]: - """ - Update configuration with the specification from a CSV file. - - The specification is a headerless CSV file with columns: Table name, - Column name (or space-separated list of column names), Generator - name required, Second choice generator name, Third choice generator - name, etcetera. - :param src_dsn: Address of the source database - :param src_schema: Name of the source database schema to read from - :param metadata: SQLAlchemy representation of the source database - :param config: Existing configuration (will be destructively updated) - :param spec_path: The path of the CSV file containing the specification - :return: Updated configuration. 
- """ - with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: - if spec_path is None: - gc.cmdloop() - return gc.config - spec = spec_path.open() - line_no = 0 - for line in csv.reader(spec): - line_no += 1 - if line: - if len(line) != 3: - logger.error( - "line %d of file %s does not have three values", - line_no, - spec_path, - ) - if gc.go_to(f"{line[0]}.{line[1]}"): - gc.do_set(line[2]) - gc.do_quit("yes") - return gc.config diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py new file mode 100644 index 00000000..72ec00e2 --- /dev/null +++ b/datafaker/interactive/__init__.py @@ -0,0 +1,95 @@ +"""Interactive configuration commands.""" +import csv +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Any + +from sqlalchemy import MetaData + +from datafaker.interactive.table import TableCmd +from datafaker.interactive.generators import GeneratorCmd +from datafaker.interactive.missingness import MissingnessCmd +from datafaker.utils import logger + +# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13 +# Windows users can install pyreadline3 to get tab completion working. +# See https://github.com/pyreadline3/pyreadline3/issues/37 +try: + import readline + + if not hasattr(readline, "backend"): + setattr(readline, "backend", "readline") +except: + pass + + +def update_config_tables( + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping +) -> Mapping[str, Any]: + """Ask the user to specify what should happen to each table.""" + with TableCmd(src_dsn, src_schema, metadata, config) as tc: + tc.cmdloop() + return tc.config + + +def update_missingness( + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], +) -> Mapping[str, Any]: + """ + Ask the user to update the missingness information in ``config.yaml``. + + :param src_dsn: The connection string for the source database. 
+ :param src_schema: The name of the source database schema (or None + for the default). + :param metadata: The SQLAlchemy metadata object from ``orm.yaml``. + :param config: The starting configuration, + :return: The updated configuration. + """ + with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: + mc.cmdloop() + return mc.config + + +def update_config_generators( + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + spec_path: Path | None, +) -> Mapping[str, Any]: + """ + Update configuration with the specification from a CSV file. + + The specification is a headerless CSV file with columns: Table name, + Column name (or space-separated list of column names), Generator + name required, Second choice generator name, Third choice generator + name, etcetera. + :param src_dsn: Address of the source database + :param src_schema: Name of the source database schema to read from + :param metadata: SQLAlchemy representation of the source database + :param config: Existing configuration (will be destructively updated) + :param spec_path: The path of the CSV file containing the specification + :return: Updated configuration. 
+ """ + with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: + if spec_path is None: + gc.cmdloop() + return gc.config + spec = spec_path.open() + line_no = 0 + for line in csv.reader(spec): + line_no += 1 + if line: + if len(line) != 3: + logger.error( + "line %d of file %s does not have three values", + line_no, + spec_path, + ) + if gc.go_to(f"{line[0]}.{line[1]}"): + gc.do_set(line[2]) + gc.do_quit("yes") + return gc.config diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py new file mode 100644 index 00000000..51793fe0 --- /dev/null +++ b/datafaker/interactive/base.py @@ -0,0 +1,404 @@ +"""Base configuration command shells.""" +import cmd +from abc import ABC, abstractmethod +from collections.abc import Mapping, MutableMapping, Sequence +from dataclasses import dataclass +from enum import Enum +from types import TracebackType +from typing import Any, Optional, Type + +import sqlalchemy +from prettytable import PrettyTable +from sqlalchemy import Engine, ForeignKey, MetaData, Table +from typing_extensions import Self + +from datafaker.utils import ( + T, + create_db_engine, + fk_refers_to_ignored_table, + get_sync_engine, +) + + +def or_default(v: T | None, d: T) -> T: + """Return v if it isn't None, otherwise d.""" + return d if v is None else v + + +class TableType(Enum): + """Types of table to be configured.""" + + GENERATE = "generate" + IGNORE = "ignore" + VOCABULARY = "vocabulary" + PRIVATE = "private" + EMPTY = "empty" + + +TYPE_LETTER = { + TableType.GENERATE: "G", + TableType.IGNORE: "I", + TableType.VOCABULARY: "V", + TableType.PRIVATE: "P", + TableType.EMPTY: "e", +} + +TYPE_PROMPT = { + TableType.GENERATE: "(table: {}) ", + TableType.IGNORE: "(table: {} (ignore)) ", + TableType.VOCABULARY: "(table: {} (vocab)) ", + TableType.PRIVATE: "(table: {} (private)) ", + TableType.EMPTY: "(table: {} (empty))", +} + + +@dataclass +class TableEntry: + """Base class for table entries for interactive commands.""" + + name: str # 
name of the table + + +class AskSaveCmd(cmd.Cmd): + """Interactive shell for whether to save and quit.""" + + intro = "Do you want to save this configuration?" + prompt = "(yes/no/cancel) " + file = None + + def __init__(self) -> None: + """Initialise a save command.""" + super().__init__() + self.result = "" + + def do_yes(self, _arg: str) -> bool: + """Save the new config.yaml.""" + self.result = "yes" + return True + + def do_no(self, _arg: str) -> bool: + """Exit without saving.""" + self.result = "no" + return True + + def do_cancel(self, _arg: str) -> bool: + """Do not exit.""" + self.result = "cancel" + return True + + +def fk_column_name(fk: ForeignKey) -> str: + """Display name for a foreign key.""" + if fk_refers_to_ignored_table(fk): + return f"{fk.target_fullname} (ignored)" + return str(fk.target_fullname) + + +class DbCmd(ABC, cmd.Cmd): + """Base class for interactive configuration commands.""" + + INFO_NO_MORE_TABLES = "There are no more tables" + ERROR_ALREADY_AT_START = "Error: Already at the start" + ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" + ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" + ROW_COUNT_MSG = "Total row count: {}" + + @abstractmethod + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> TableEntry | None: + """ + Make a table entry suitable for this interactive command. + + :param name: The name of the table to make an entry for. + :param table_config: The part of the ``config.yaml`` referring to this table. + :return: The table entry or None if this table should not be interacted with. 
+ """ + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ): + """Initialise a DbCmd.""" + super().__init__() + self.config: MutableMapping[str, Any] = config + self.metadata = metadata + self._table_entries: list[TableEntry] = [] + tables_config: MutableMapping = config.get("tables", {}) + if not isinstance(tables_config, MutableMapping): + tables_config = {} + for name in metadata.tables.keys(): + table_config = tables_config.get(name, {}) + if not isinstance(table_config, MutableMapping): + table_config = {} + entry = self.make_table_entry(name, table_config) + if entry is not None: + self._table_entries.append(entry) + self.table_index = 0 + self.engine = create_db_engine(src_dsn, schema_name=src_schema) + + @property + def sync_engine(self) -> Engine: + """Get the synchronous version of the engine.""" + return get_sync_engine(self.engine) + + def __enter__(self) -> Self: + """Enter a ``with`` statement.""" + return self + + def __exit__( + self, + _exc_type: Optional[Type[BaseException]], + _exc_val: Optional[BaseException], + _exc_tb: Optional[TracebackType], + ) -> None: + """Dispose of this object.""" + self.engine.dispose() + + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Print text, formatted with positional and keyword arguments.""" + print(text.format(*args, **kwargs)) + + def print_table( + self, headings: Sequence[str], rows: Sequence[Sequence[Any]] + ) -> None: + """ + Print a table. + + :param headings: List of headings for the table. + :param rows: List of rows of values. + """ + output = PrettyTable() + output.field_names = headings + for row in rows: + # Hopefully PrettyTable will accept Sequence in the future, not list + output.add_row(list(row)) + print(output) + + def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: + """ + Print a table. + + :param columns: Dict of column names to the values in the column. 
+ """ + output = PrettyTable() + row_count = max([len(col) for col in columns.values()]) + for field_name, data in columns.items(): + output.add_column(field_name, list(data) + [None] * (row_count - len(data))) + print(output) + + def print_results(self, result: sqlalchemy.CursorResult) -> None: + """Print the rows resulting from a database query.""" + self.print_table(list(result.keys()), [list(row) for row in result.all()]) + + def ask_save(self) -> str: + """ + Ask the user if they want to save. + + :return: ``yes``, ``no`` or ``cancel``. + """ + ask = AskSaveCmd() + ask.cmdloop() + return ask.result + + @abstractmethod + def set_prompt(self) -> None: + """Set the prompt according to the current state.""" + ... + + def _set_table_index(self, index: int) -> bool: + """ + Move to a different table. + + :param index: Index of the table to move to. + :return: True if there is a table with such an index to move to. + """ + if 0 <= index < len(self._table_entries): + self.table_index = index + self.set_prompt() + return True + return False + + def next_table(self, report: str = "No more tables") -> bool: + """ + Move to the next table. + + :param report: The text to print if there is no next table. + :return: True if there is another table to move to. 
+ """ + if not self._set_table_index(self.table_index + 1): + self.print(report) + return False + return True + + def table_name(self) -> str: + """Get the name of the current table.""" + return str(self._table_entries[self.table_index].name) + + def table_metadata(self) -> Table: + """Get the metadata of the current table.""" + return self.metadata.tables[self.table_name()] + + def _get_column_names(self) -> list[str]: + """Get the names of the current columns.""" + return [col.name for col in self.table_metadata().columns] + + def report_columns(self) -> None: + """Print information about the current columns.""" + self.print_table( + ["name", "type", "primary", "nullable", "foreign key"], + [ + [ + name, + str(col.type), + col.primary_key, + col.nullable, + ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), + ] + for name, col in self.table_metadata().columns.items() + ], + ) + + def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: + """Get the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + return {} + t = ts.get(table_name) + return t if isinstance(t, MutableMapping) else {} + + def set_table_config( + self, table_name: str, config: MutableMapping[str, Any] + ) -> None: + """Set the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + self.config["tables"] = {table_name: config} + return + ts[table_name] = config + + def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: + """Remove all source stats with the given prefix from the configuration.""" + src_stats = self.config.get("src-stats", []) + new_src_stats = [] + for stat in src_stats: + if not stat.get("name", "").startswith(prefix): + new_src_stats.append(stat) + self.config["src-stats"] = new_src_stats + return new_src_stats + + def get_nonnull_columns(self, table_name: str) -> list[str]: + """Get the names of the 
nullable columns in the named table.""" + metadata_table = self.metadata.tables[table_name] + return [ + str(name) + for name, column in metadata_table.columns.items() + if column.nullable + ] + + def find_entry_index_by_table_name(self, table_name: str) -> int | None: + """Get the index of the table entry of the named table.""" + return next( + ( + i + for i, entry in enumerate(self._table_entries) + if entry.name == table_name + ), + None, + ) + + def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None: + """Get the table entry of the named table.""" + for e in self._table_entries: + if e.name == table_name: + return e + return None + + def do_counts(self, _arg: str) -> None: + """Report the column names with the counts of nulls in them.""" + if len(self._table_entries) <= self.table_index: + return + table_name = self.table_name() + nonnull_columns = self.get_nonnull_columns(table_name) + colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns] + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" + ) + ).first() + if result is None: + self.print("Could not count rows in table {0}", table_name) + return + row_count = result.row_count + self.print(self.ROW_COUNT_MSG, row_count) + self.print_table( + ["Column", "NULL count"], + [ + [name, row_count - count] + for name, count in result._mapping.items() + if name != "row_count" + ], + ) + + def do_select(self, arg: str) -> None: + """Run a select query over the database and show the first 50 results.""" + max_select_rows = 50 + with self.sync_engine.connect() as connection: + try: + result = connection.execute(sqlalchemy.text("SELECT " + arg)) + except sqlalchemy.exc.DatabaseError as exc: + self.print("Failed to execute: {}", exc) + return + row_count = result.rowcount + self.print(self.ROW_COUNT_MSG, row_count) + if 50 < row_count: + self.print("Showing the first {} 
rows", max_select_rows) + fields = list(result.keys()) + rows = result.fetchmany(max_select_rows) + self.print_table(fields, rows) + + def do_peek(self, arg: str) -> None: + """ + View some data from the current table. + + Use 'peek col1 col2 col3' to see a sample of values from + columns col1, col2 and col3 in the current table. + Use 'peek' to see a sample of the current column(s). + Rows that are enitrely null are suppressed. + """ + max_peek_rows = 25 + if len(self._table_entries) <= self.table_index: + return + table_name = self.table_name() + col_names = arg.split() + if not col_names: + col_names = self._get_column_names() + nonnulls = [cn + " IS NOT NULL" for cn in col_names] + with self.sync_engine.connect() as connection: + cols = ",".join(col_names) + where = "WHERE" if nonnulls else "" + nonnull = " OR ".join(nonnulls) + query = sqlalchemy.text( + f"SELECT {cols} FROM {table_name} {where} {nonnull}" + f" ORDER BY RANDOM() LIMIT {max_peek_rows}" + ) + try: + result = connection.execute(query) + except Exception as exc: + self.print(f'SQL query "{query}" caused exception {exc}') + return + self.print_table(list(result.keys()), result.fetchmany(max_peek_rows)) + + def complete_peek( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Completions for the ``peek`` command.""" + if len(self._table_entries) <= self.table_index: + return [] + return [ + col for col in self.table_metadata().columns.keys() if col.startswith(text) + ] diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py new file mode 100644 index 00000000..4c87c6b2 --- /dev/null +++ b/datafaker/interactive/generators.py @@ -0,0 +1,980 @@ +"""Generator configuration shell.""" +from dataclasses import dataclass +from collections.abc import Mapping, Sequence, Iterable, MutableMapping +import functools +import re +from typing import Any, Optional, cast, Callable + +import sqlalchemy +from sqlalchemy import Column, MetaData + +from 
datafaker.generators import everything_factory +from datafaker.generators.base import Generator, PredefinedGenerator +from datafaker.interactive.base import TableEntry, DbCmd, fk_column_name, or_default +from datafaker.utils import logger, table_is_private, primary_private_fks + +@dataclass +class GeneratorInfo: + """A generator and the columns it assigns to.""" + + columns: list[str] + gen: Generator | None + + +@dataclass +class GeneratorCmdTableEntry(TableEntry): + """ + List of generators set for a table. + + Includes the original setting and the currently configured + generators. + """ + + old_generators: list[GeneratorInfo] + new_generators: list[GeneratorInfo] + + +class GeneratorCmd(DbCmd): + """Interactive command shell for setting generators.""" + + intro = "Interactive generator configuration. Type ? for help.\n" + doc_leader = """Use command 'propose' for a list of generators applicable to the +current column, then command 'compare' to see how these perform +against the source data, then command 'set' to choose your favourite. +Use 'unset' to remove the column's generator. Use commands 'next' and +'previous' to change which column we are examining. Use 'info' +for useful information about the current column. Use 'tables' and +'list' to see available tables and columns. Use 'columns' to see +information about the columns in the current table. Use 'peek', +'count' or 'select' to fetch data from the source database. Use +'quit' to exit this program.""" + prompt = "(generatorconf) " + file = None + + PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..." + PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column." + PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..." 
+ PRIMARY_PRIVATE_TEXT = "Primary Private" + SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}" + NOT_PRIVATE_TEXT = "Not private" + ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}" + ERROR_NO_SUCH_COLUMN = "No such column {0} in this table" + ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged" + ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged" + ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" + PROPOSE_NOTHING = "No proposed generators, sorry." + + SRC_STAT_RE = re.compile( + r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' + ) + + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> GeneratorCmdTableEntry | None: + """ + Make a table entry. + + :param table_name: The name of the table. + :param table: The portion of the ``config.yaml`` file describing this table. + :return: The newly constructed table entry, or None if this table is to be ignored. + """ + if table_config.get("ignore", False): + return None + if table_config.get("vocabulary_table", False): + return None + if table_config.get("num_rows_per_pass", 1) == 0: + return None + metadata_table = self.metadata.tables[table_name] + columns = [str(colname) for colname in metadata_table.columns.keys()] + column_set = frozenset(columns) + columns_assigned_so_far: set[str] = set() + + new_generator_infos: list[GeneratorInfo] = [] + old_generator_infos: list[GeneratorInfo] = [] + for rg in table_config.get("row_generators", []): + gen_name = rg.get("name", None) + if gen_name: + ca = rg.get("columns_assigned", []) + collist: list[str] = ( + [ca] if isinstance(ca, str) else [str(c) for c in ca] + ) + colset: set[str] = set(collist) + for unknown in colset - column_set: + logger.warning( + "table '%s' has '%s' assigned to column '%s' which is not in this table", + table_name, + gen_name, + unknown, + ) + for mult in columns_assigned_so_far & colset: + logger.warning( + "table '%s' has column '%s' 
assigned to multiple times", + table_name, + mult, + ) + actual_collist = [c for c in collist if c in columns] + if actual_collist: + gen = PredefinedGenerator(table_name, rg, self.config) + new_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + old_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + columns_assigned_so_far |= colset + for colname in columns: + if colname not in columns_assigned_so_far: + new_generator_infos.append( + GeneratorInfo( + columns=[colname], + gen=None, + ) + ) + if len(new_generator_infos) == 0: + return None + return GeneratorCmdTableEntry( + name=table_name, + old_generators=old_generator_infos, + new_generators=new_generator_infos, + ) + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ) -> None: + """ + Initialise a ``GeneratorCmd``. + + :param src_dsn: connection address for source database + :param src_schema: database schema name + :param metadata: SQLAlchemy metadata for the source database + :param config: Configuration loaded from ``config.yaml`` + """ + super().__init__(src_dsn, src_schema, metadata, config) + self.generators: list[Generator] | None = None + self.generator_index = 0 + self.generators_valid_columns: Optional[tuple[int, list[str]]] = None + self.set_prompt() + + @property + def table_entries(self) -> list[GeneratorCmdTableEntry]: + """Get the talbe entries, cast to ``GeneratorCmdTableEntry``.""" + return cast(list[GeneratorCmdTableEntry], self._table_entries) + + def _find_entry_by_table_name( + self, table_name: str + ) -> GeneratorCmdTableEntry | None: + """ + Find the table entry by name. + + :param table_name: The name of the table to find. + :return: The table entry, or None if no such table name exists. 
+ """ + entry = super()._find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(GeneratorCmdTableEntry, entry) + + def _set_table_index(self, index: int) -> bool: + """ + Move to a new table. + + :param index: table index to move to. + """ + ret = super()._set_table_index(index) + if ret: + self.generator_index = 0 + self.set_prompt() + return ret + + def _previous_table(self) -> bool: + """ + Move to the table before the current one. + + :return: True if there is a previous table to go to. + """ + ret = self._set_table_index(self.table_index - 1) + if ret: + table = self.get_table() + if table is None: + self.print( + "Internal error! table {0} does not have any generators!", + self.table_index, + ) + return False + self.generator_index = len(table.new_generators) - 1 + else: + self.print(self.ERROR_ALREADY_AT_START) + return ret + + def get_table(self) -> GeneratorCmdTableEntry | None: + """Get the current table entry.""" + if self.table_index < len(self.table_entries): + return self.table_entries[self.table_index] + return None + + def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: + """Get a pair; the table name then the generator information.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + if self.generator_index < len(entry.new_generators): + return (entry.name, entry.new_generators[self.generator_index]) + return (entry.name, None) + return (None, None) + + def _get_column_names(self) -> list[str]: + """Get the (unqualified) names for all the current columns.""" + (_, generator_info) = self._get_table_and_generator() + return generator_info.columns if generator_info else [] + + def _column_metadata(self) -> list[Column]: + """Get the metadata for all the current columns.""" + table = self.table_metadata() + if table is None: + return [] + return [table.columns[name] for name in self._get_column_names()] + + def set_prompt(self) -> None: + """Set the 
prompt according to the current table, column and generator.""" + (table_name, gen_info) = self._get_table_and_generator() + if table_name is None: + self.prompt = "(generators) " + return + if gen_info is None: + self.prompt = f"({table_name}) " + return + table = self.table_metadata() + columns = [ + c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns + ] + gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" + self.prompt = f"({table_name}.{','.join(columns)}{gen}) " + + def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]: + """ + Remove all automatic source stats. + + We assume every source stats query whose name begins with ``auto__` + :return: The new ``src_stats`` configuration. + """ + return self._remove_prefix_src_stats("auto__") + + def _copy_entries(self) -> None: + """Set generator and query information in the configuration.""" + src_stats = self._remove_auto_src_stats() + for entry in self.table_entries: + rgs = [] + new_gens: list[Generator] = [] + for generator in entry.new_generators: + if generator.gen is not None: + new_gens.append(generator.gen) + cqs = generator.gen.custom_queries() + for cq_key, cq in cqs.items(): + src_stats.append( + { + "name": cq_key, + "query": cq["query"], + "comments": [cq["comment"]] + if "comment" in cq and cq["comment"] + else [], + } + ) + rg: dict[str, Any] = { + "name": generator.gen.function_name(), + "columns_assigned": generator.columns, + } + kwn = generator.gen.nominal_kwargs() + if kwn: + rg["kwargs"] = kwn + rgs.append(rg) + aq = self._get_aggregate_query(new_gens, entry.name) + if aq: + src_stats.append( + { + "name": f"auto__{entry.name}", + "query": aq, + "comments": [ + q["comment"] + for gen in new_gens + for q in gen.select_aggregate_clauses().values() + if "comment" in q and q["comment"] is not None + ], + } + ) + table_config = self.get_table_config(entry.name) + if rgs: + table_config["row_generators"] = rgs + elif "row_generators" in table_config: + del 
table_config["row_generators"] + self.set_table_config(entry.name, table_config) + self.config["src-stats"] = src_stats + + def _find_old_generator( + self, entry: GeneratorCmdTableEntry, columns: Iterable[str] + ) -> Generator | None: + """Find any generator that previously assigned to these exact same columns.""" + fc = frozenset(columns) + for gen in entry.old_generators: + if frozenset(gen.columns) == fc: + return gen.gen + return None + + def do_quit(self, arg: str) -> bool: + """Check the updates, save them if desired and quit the configurer.""" + count = 0 + for entry in self.table_entries: + header_shown = False + g_entry = cast(GeneratorCmdTableEntry, entry) + for gen in g_entry.new_generators: + old_gen = self._find_old_generator(g_entry, gen.columns) + new_gen = None if gen is None else gen.gen + if old_gen != new_gen: + if not header_shown: + header_shown = True + self.print("Table {0}:", entry.name) + count += 1 + self.print( + "...changing {0} from {1} to {2}", + ", ".join(gen.columns), + old_gen.name() if old_gen else "nothing", + gen.gen.name() if gen.gen else "nothing", + ) + if count == 0: + self.print("You have made no changes.") + if arg in {"yes", "no"}: + reply = arg + else: + reply = self.ask_save() + if reply == "yes": + self._copy_entries() + return True + if reply == "no": + return True + return False + + def do_tables(self, _arg: str) -> None: + """List the tables.""" + for t_entry in self.table_entries: + entry = cast(GeneratorCmdTableEntry, t_entry) + gen_count = len(entry.new_generators) + how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" + self.print("{0} ({1})", entry.name, how_many) + + def do_list(self, _arg: str) -> None: + """List the generators in the current table.""" + if len(self.table_entries) <= self.table_index: + self.print("Error: no table {0}", self.table_index) + return + g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) + table = self.table_metadata() + for gen in 
g_entry.new_generators: + old_gen = self._find_old_generator(g_entry, gen.columns) + old = "" if old_gen is None else old_gen.name() + if old_gen == gen.gen: + becomes = "" + if old == "": + old = "(not set)" + elif gen.gen is None: + becomes = "(delete)" + else: + becomes = f"->{gen.gen.name()}" + primary = "" + if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: + primary = "[primary-key]" + self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) + + def do_columns(self, _arg: str) -> None: + """Report the column names and metadata.""" + self.report_columns() + + def do_info(self, _arg: str) -> None: + """Show information about the current column.""" + for cm in self._column_metadata(): + self.print( + "Column {0} in table {1} has type {2} ({3}).", + cm.name, + cm.table.name, + str(cm.type), + "nullable" if cm.nullable else "not nullable", + ) + if cm.primary_key: + self.print( + "It is a primary key, which usually does not" + " need a generator (it will auto-increment)" + ) + if cm.foreign_keys: + fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] + self.print( + "It is a foreign key referencing column {0}", ", ".join(fk_names) + ) + if len(fk_names) == 1 and not cm.primary_key: + self.print( + "You do not need a generator if you just want" + " a uniform choice over the referenced table's rows" + ) + + def _get_table_index(self, table_name: str) -> int | None: + """Get the index of the named table in the table entries list.""" + for n, entry in enumerate(self.table_entries): + if entry.name == table_name: + return n + return None + + def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + """ + Get the index number of a column within the list of generators in this table. + + :param table_index: The index of the table in which to search. + :param column_name: The name of the column to search for. 
+ :return: The index in the ``new_generators`` attribute of the table entry + containing the specified column, or None if this does not exist. + """ + entry = self.table_entries[table_index] + for n, gen in enumerate(entry.new_generators): + if column_name in gen.columns: + return n + return None + + def go_to(self, target: str) -> bool: + """ + Go to a particular column. + + :return: True on success. + """ + parts = target.split(".", 1) + table_index = self._get_table_index(parts[0]) + if table_index is None: + if len(parts) == 1: + gen_index = self._get_generator_index(self.table_index, parts[0]) + if gen_index is not None: + self.generator_index = gen_index + self.set_prompt() + return True + self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0]) + return False + gen_index = None + if 1 < len(parts) and parts[1]: + gen_index = self._get_generator_index(table_index, parts[1]) + if gen_index is None: + self.print("we cannot set the generator for column {0}", parts[1]) + return False + self._set_table_index(table_index) + if gen_index is not None: + self.generator_index = gen_index + self.set_prompt() + return True + + def do_next(self, arg: str) -> None: + """ + Go to the next generator. or a specified generator. + + Go to a named table: 'next tablename', + go to a column: 'next tablename.columnname', + or go to a column within this table: 'next columnname'. 
+ """ + if arg: + self.go_to(arg) + else: + self._go_next() + + def do_n(self, arg: str) -> None: + """Go to the next generator, or a specified generator.""" + self.do_next(arg) + + def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: + """Complete the ``n`` command's arguments.""" + return self.complete_next(text, line, begidx, endidx) + + def _go_next(self) -> None: + """Go to the next column.""" + table = self.get_table() + if table is None: + self.print("No more tables") + return + next_gi = self.generator_index + 1 + if next_gi == len(table.new_generators): + self.next_table(self.INFO_NO_MORE_TABLES) + return + self.generator_index = next_gi + self.set_prompt() + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Completions for the arguments of the ``next`` command.""" + parts = text.split(".", 1) + first_part = parts[0] + if 1 < len(parts): + column_name = parts[1] + table_index = self._get_table_index(first_part) + if table_index is None: + return [] + table_entry = self.table_entries[table_index] + return [ + f"{first_part}.{column}" + for gen in table_entry.new_generators + for column in gen.columns + if column.startswith(column_name) + ] + table_names = [ + entry.name + for entry in self.table_entries + if entry.name.startswith(first_part) + ] + if first_part in table_names: + table_names.append(f"{first_part}.") + current_table = self.get_table() + if current_table: + column_names = [ + col + for gen in current_table.new_generators + for col in gen.columns + if col.startswith(first_part) + ] + else: + column_names = [] + return table_names + column_names + + def do_previous(self, _arg: str) -> None: + """Go to the previous generator.""" + if self.generator_index == 0: + self._previous_table() + else: + self.generator_index -= 1 + self.set_prompt() + + def do_b(self, arg: str) -> None: + """Synonym for previous.""" + self.do_previous(arg) + + def _generators_valid(self) -> 
bool: + """Test if ``self.generators`` is still correct for the current columns.""" + return self.generators_valid_columns == ( + self.table_index, + self._get_column_names(), + ) + + def _get_generator_proposals(self) -> list[Generator]: + """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" + if not self._generators_valid(): + self.generators = None + if self.generators is None: + columns = self._column_metadata() + gens = everything_factory().get_generators(columns, self.sync_engine) + sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) + self.generators = sorted_gens + self.generators_valid_columns = ( + self.table_index, + self._get_column_names().copy(), + ) + return self.generators + + def _print_privacy(self) -> None: + """Print the privacy status of the current table.""" + table = self.table_metadata() + if table is None: + return + if table_is_private(self.config, table.name): + self.print(self.PRIMARY_PRIVATE_TEXT) + return + pfks = primary_private_fks(self.config, table) + if not pfks: + self.print(self.NOT_PRIVATE_TEXT) + return + self.print(self.SECONDARY_PRIVATE_TEXT, pfks) + + def do_compare(self, arg: str) -> None: + """ + Compare the real data with some generators. + + 'compare': just look at some source data from this column. + 'compare 5 6 10': compare a sample of the source data with a sample + from generators 5, 6 and 10. You can find out which numbers + correspond to which generators using the 'propose' command. + """ + self._print_privacy() + args = arg.split() + limit = 20 + comparison = { + "source": [ + x[0] if len(x) == 1 else ", ".join(x) + for x in self._get_column_data(limit, to_str=str) + ] + } + gens: list[Generator] = self._get_generator_proposals() + table_name = self.table_name() + for argument in args: + if argument.isdigit(): + n = int(argument) + if 0 < n <= len(gens): + gen = gens[n - 1] + comparison[f"{n}. 
{gen.name()}"] = gen.generate_data(limit) + self._print_values_queried(table_name, n, gen) + self.print_table_by_columns(comparison) + + def do_c(self, arg: str) -> None: + """Synonym for compare.""" + self.do_compare(arg) + + def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: + """ + Print the values queried from the database for this generator. + + :param table_name: The name of the table the generator applies to. + :param n: A number to print at the start of the output. + :param gen: The generator to report. + """ + if not gen.select_aggregate_clauses() and not gen.custom_queries(): + self.print( + "{0}. {1} requires no data from the source database.", + n, + gen.name(), + ) + else: + self.print( + "{0}. {1} requires the following data from the source database:", + n, + gen.name(), + ) + self._print_select_aggregate_query(table_name, gen) + self._print_custom_queries(gen) + + def _print_custom_queries(self, gen: Generator) -> None: + """ + Print all the custom queries and all the values they get in this case. + + :param gen: The generator to print the custom queries for. + """ + cqs = gen.custom_queries() + if not cqs: + return + cq_key2args: dict[str, Any] = {} + nominal = gen.nominal_kwargs() + actual = gen.actual_kwargs() + self._get_custom_queries_from( + cq_key2args, + nominal, + actual, + ) + for cq_key, cq in cqs.items(): + self.print( + "{0}; providing the following values: {1}", + cq["query"], + cq_key2args[cq_key], + ) + + def _get_custom_queries_from( + self, out: dict[str, Any], nominal: Any, actual: Any + ) -> None: + if isinstance(nominal, str): + src_stat_groups = self.SRC_STAT_RE.search(nominal) + # Do we have a SRC_STAT reference? + if src_stat_groups: + # Get its name + cq_key = src_stat_groups.group(1) + # Are we pulling a specific part of this result? 
+ sub = src_stat_groups.group(3) + if sub: + actual = {sub: actual} + else: + out[cq_key] = actual + elif isinstance(nominal, Sequence) and isinstance(actual, Sequence): + for i in range(min(len(nominal), len(actual))): + self._get_custom_queries_from(out, nominal[i], actual[i]) + elif isinstance(nominal, Mapping) and isinstance(actual, Mapping): + for k, v in nominal.items(): + if k in actual: + self._get_custom_queries_from(out, v, actual[k]) + + def _get_aggregate_query( + self, gens: list[Generator], table_name: str + ) -> str | None: + clauses = [ + f'{q["clause"]} AS {n}' + for gen in gens + for n, q in or_default(gen.select_aggregate_clauses(), {}).items() + ] + if not clauses: + return None + return f"SELECT {', '.join(clauses)} FROM {table_name}" + + def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: + """ + Print the select aggregate query and all the values it gets in this case. + + This is not the entire query that will be executed, but only the part of it + that is required by a certain generator. + :param table_name: The table name. + :param gen: The generator to limit the aggregate query to. 
+ """ + sacs = gen.select_aggregate_clauses() + if not sacs: + return + kwa = gen.actual_kwargs() + vals = [] + src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} + for n in sacs.keys(): + src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' + if src_stat in src_stat2kwarg: + ak = src_stat2kwarg[src_stat] + if ak in kwa: + vals.append(kwa[ak]) + else: + logger.warning( + "actual_kwargs for %s does not report %s", gen.name(), ak + ) + else: + logger.warning( + ( + "nominal_kwargs for %s does not have a value" + ' SRC_STATS["auto__%s"]["results"][0]["%s"]' + ), + gen.name(), + table_name, + n, + ) + select_q = self._get_aggregate_query([gen], table_name) + self.print("{0}; providing the following values: {1}", select_q, vals) + + def _get_column_data( + self, count: int, to_str: Callable[[Any], str] = repr + ) -> list[list[str]]: + columns = self._get_column_names() + columns_string = ", ".join(columns) + pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT {columns_string} FROM {self.table_name()}" + f" WHERE {pred} ORDER BY RANDOM() LIMIT {count}" + ) + ) + return [[to_str(x) for x in xs] for xs in result.all()] + + def do_propose(self, _arg: str) -> None: + """ + Display a list of possible generators for this column. + + They will be listed in order of fit, the most likely matches first. + The results can be compared (against a sample of the real data in + the column and against each other) with the 'compare' command. 
+ """ + limit = 5 + gens = self._get_generator_proposals() + sample = self._get_column_data(limit) + if sample: + rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] + self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) + else: + self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) + if not gens: + self.print(self.PROPOSE_NOTHING) + for index, gen in enumerate(gens): + fit = gen.fit(-1) + if fit == -1: + fit_s = "(no fit)" + elif fit < 100: + fit_s = f"(fit: {fit:.3g})" + else: + fit_s = f"(fit: {fit:.0f})" + self.print( + self.PROPOSE_GENERATOR_SAMPLE_TEXT, + index=index + 1, + name=gen.name(), + fit=fit_s, + sample="; ".join(map(repr, gen.generate_data(limit))), + ) + + def do_p(self, arg: str) -> None: + """Synonym for propose.""" + self.do_propose(arg) + + def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: + """Find a generator by name from the list of proposals.""" + for gen in self._get_generator_proposals(): + if gen.name() == gen_name: + return gen + return None + + def do_set(self, arg: str) -> None: + """Set one of the proposals as a generator.""" + if arg.isdigit() and not self._generators_valid(): + self.print("Please run 'propose' before 'set '") + return + gens = self._get_generator_proposals() + new_gen: Generator | None + if arg.isdigit(): + index = int(arg) + if index < 1: + self.print("set's integer argument must be at least 1") + return + if len(gens) < index: + self.print( + "There are currently only {0} generators proposed, please select one of them.", + len(gens), + ) + return + new_gen = gens[index - 1] + else: + new_gen = self.get_proposed_generator_by_name(arg) + if new_gen is None: + self.print("'{0}' is not an appropriate generator for this column", arg) + return + self.set_generator(new_gen) + self._go_next() + + def set_generator(self, gen: Generator | None) -> None: + """Set the current column's generator.""" + (table, gen_info) = self._get_table_and_generator() + if table is None: + self.print("Error: 
no table") + return + if gen_info is None: + self.print("Error: no column") + return + gen_info.gen = gen + + def do_s(self, arg: str) -> None: + """Synonym for set.""" + self.do_set(arg) + + def do_unset(self, _arg: str) -> None: + """Remove any generator set for this column.""" + self.set_generator(None) + self._go_next() + + def do_merge(self, arg: str) -> None: + """ + Add this column(s) to the specified column(s). + + After this, one generator will cover them all. + """ + cols = arg.split() + if not cols: + self.print("Error: merge requires a column argument") + table_entry: GeneratorCmdTableEntry | None = self.get_table() + if table_entry is None: + self.print(self.ERROR_NO_SUCH_TABLE) + return + cols_available = functools.reduce( + lambda x, y: x | y, + [frozenset(gen.columns) for gen in table_entry.new_generators], + ) + cols_to_merge = frozenset(cols) + unknown_cols = cols_to_merge - cols_available + if unknown_cols: + for uc in unknown_cols: + self.print(self.ERROR_NO_SUCH_COLUMN, uc) + return + gen_info = table_entry.new_generators[self.generator_index] + current_columns = frozenset(gen_info.columns) + stated_current_columns = cols_to_merge & current_columns + if stated_current_columns: + for c in stated_current_columns: + self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) + return + # Remove cols_to_merge from each generator + new_new_generators: list[GeneratorInfo] = [] + for gen in table_entry.new_generators: + if gen is gen_info: + # Add columns to this generator + self.generator_index = len(new_new_generators) + new_new_generators.append( + GeneratorInfo( + columns=gen.columns + cols, + gen=None, + ) + ) + else: + # Remove columns if applicable + new_columns = [c for c in gen.columns if c not in cols_to_merge] + is_changed = len(new_columns) != len(gen.columns) + if new_columns: + # We have not removed this generator completely + new_new_generators.append( + GeneratorInfo( + columns=new_columns, + gen=None if is_changed else gen.gen, + ) + ) + 
table_entry.new_generators = new_new_generators + self.set_prompt() + + def complete_merge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Complete column names.""" + last_arg = text.split()[-1] + table_entry: GeneratorCmdTableEntry | None = self.get_table() + if table_entry is None: + return [] + return [ + column + for i, gen in enumerate(table_entry.new_generators) + if i != self.generator_index + for column in gen.columns + if column.startswith(last_arg) + ] + + def do_unmerge(self, arg: str) -> None: + """Remove this column(s) from this generator, make them a separate generator.""" + cols = arg.split() + if not cols: + self.print("Error: merge requires a column argument") + table_entry: GeneratorCmdTableEntry | None = self.get_table() + if table_entry is None: + self.print(self.ERROR_NO_SUCH_TABLE) + return + gen_info = table_entry.new_generators[self.generator_index] + current_columns = frozenset(gen_info.columns) + cols_to_unmerge = frozenset(cols) + unknown_cols = cols_to_unmerge - current_columns + if unknown_cols: + for uc in unknown_cols: + self.print(self.ERROR_NO_SUCH_COLUMN, uc) + return + stated_unmerged_columns = cols_to_unmerge - current_columns + if stated_unmerged_columns: + for c in stated_unmerged_columns: + self.print(self.ERROR_COLUMN_ALREADY_UNMERGED, c) + return + if cols_to_unmerge == current_columns: + self.print(self.ERROR_CANNOT_UNMERGE_ALL) + return + # Remove unmerged columns + for um in cols_to_unmerge: + gen_info.columns.remove(um) + # The existing generator will not work + gen_info.gen = None + # And put them into a new (empty) generator + table_entry.new_generators.insert( + self.generator_index + 1, + GeneratorInfo( + columns=cols, + gen=None, + ), + ) + self.set_prompt() + + def complete_unmerge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Complete column names to unmerge.""" + last_arg = text.split()[-1] + table_entry: GeneratorCmdTableEntry | None = 
self.get_table() + if table_entry is None: + return [] + return [ + column + for column in table_entry.new_generators[self.generator_index].columns + if column.startswith(last_arg) + ] diff --git a/datafaker/interactive/missingness.py b/datafaker/interactive/missingness.py new file mode 100644 index 00000000..2737e2a8 --- /dev/null +++ b/datafaker/interactive/missingness.py @@ -0,0 +1,355 @@ +"""Missingness configuration shell.""" +from dataclasses import dataclass +from collections.abc import Iterable, Mapping, MutableMapping +import re +from typing import cast + +from sqlalchemy import MetaData + +from datafaker.interactive.base import DbCmd, TableEntry + +@dataclass +class MissingnessType: + """The functions required for applying missingness.""" + + SAMPLED = "column_presence.sampled" + SAMPLED_QUERY = ( + "SELECT COUNT(*) AS row_count, {result_names} FROM " + "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})" + " AS __t GROUP BY {result_names}" + ) + name: str + query: str + comment: str + columns: list[str] + + @classmethod + def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str: + """ + Construct a query to make a sampling of the named rows of the table. + + :param table: The name of the table to sample. + :param count: The number of samples to get. + :param column_names: The columns to fetch. + :return: The SQL query to do the sampling. + """ + result_names = ", ".join([f"{c}__is_null" for c in column_names]) + column_is_nulls = ", ".join( + [f"{c} IS NULL AS {c}__is_null" for c in column_names] + ) + return cls.SAMPLED_QUERY.format( + result_names=result_names, + column_is_nulls=column_is_nulls, + table=table, + count=count, + ) + + +@dataclass +class MissingnessCmdTableEntry(TableEntry): + """Table entry for the missingness command shell.""" + + old_type: MissingnessType + new_type: MissingnessType | None + + +class MissingnessCmd(DbCmd): + """ + Interactive shell for the user to set missingness. 
+ + Can only be used for Missingness Completely At Random. + """ + + intro = "Interactive missingness configuration. Type ? for help.\n" + doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for +the current table. Use commands 'next' and 'previous' to change the +current table. Use 'tables' to list the tables and 'count' to show +how many NULLs exist in each column. Use 'peek' or 'select' to see +data from the database. Use 'quit' to exit this tool.""" + prompt = "(missingness) " + file = None + PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]') + + def find_missingness_query( + self, missingness_generator: Mapping + ) -> tuple[str, str] | None: + """Find query and comment from src-stats for the passed missingness generator.""" + kwargs = missingness_generator.get("kwargs", {}) + patterns = kwargs.get("patterns", "") + pattern_match = self.PATTERN_RE.match(patterns) + if pattern_match: + key = pattern_match.group(1) + for src_stat in self.config["src-stats"]: + if src_stat.get("name") == key: + query = src_stat.get("query", None) + if not isinstance(query, str): + return None + return (query, src_stat.get("comment", "")) + return None + + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> MissingnessCmdTableEntry | None: + """ + Make a table entry for a particular table. + + :param name: The name of the table to make an entry for. + :param table: The part of ``config.yaml`` relating to this table. + :return: The newly-constructed table entry. 
+ """ + if table_config.get("ignore", False): + return None + if table_config.get("vocabulary_table", False): + return None + if table_config.get("num_rows_per_pass", 1) == 0: + return None + mgs = table_config.get("missingness_generators", []) + old = None + nonnull_columns = self.get_nonnull_columns(table_name) + if not nonnull_columns: + return None + if not mgs: + old = MissingnessType( + name="none", + query="", + comment="", + columns=[], + ) + elif len(mgs) == 1: + mg = mgs[0] + mg_name = mg.get("name", None) + if isinstance(mg_name, str): + query_comment = self.find_missingness_query(mg) + if query_comment is not None: + (query, comment) = query_comment + old = MissingnessType( + name=mg_name, + query=query, + comment=comment, + columns=mg.get("columns_assigned", []), + ) + if old is None: + return None + return MissingnessCmdTableEntry( + name=table_name, + old_type=old, + new_type=old, + ) + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping, + ): + """ + Initialise a MissingnessCmd. + + :param src_dsn: connection string for the source database. + :param src_schema: schema name for the source database. + :param metadata: SQLAlchemy metadata for the source database. + :param config: Configuration from the ``config.yaml`` file. 
+ """ + super().__init__(src_dsn, src_schema, metadata, config) + self.set_prompt() + + @property + def table_entries(self) -> list[MissingnessCmdTableEntry]: + """Get the table entries list.""" + return cast(list[MissingnessCmdTableEntry], self._table_entries) + + def _find_entry_by_table_name( + self, table_name: str + ) -> MissingnessCmdTableEntry | None: + """Find the table entry given the table name.""" + entry = super()._find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(MissingnessCmdTableEntry, entry) + + def set_prompt(self) -> None: + """Set the prompt according to the current table and missingness.""" + if self.table_index < len(self.table_entries): + entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] + nt = entry.new_type + if nt is None: + self.prompt = f"(missingness for {entry.name}) " + else: + self.prompt = f"(missingness for {entry.name}: {nt.name}) " + else: + self.prompt = "(missingness) " + + def set_type(self, t_type: MissingnessType) -> None: + """Set the missingness of the current table.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + entry.new_type = t_type + + def _copy_entries(self) -> None: + """Set the new missingness into the configuration.""" + src_stats = self._remove_prefix_src_stats("missing_auto__") + for entry in self.table_entries: + table = self.get_table_config(entry.name) + if entry.new_type is None or entry.new_type.name == "none": + table.pop("missingness_generators", None) + else: + src_stat_key = f"missing_auto__{entry.name}__0" + table["missingness_generators"] = [ + { + "name": entry.new_type.name, + "kwargs": { + "patterns": f'SRC_STATS["{src_stat_key}"]["results"]' + }, + "columns": entry.new_type.columns, + } + ] + src_stats.append( + { + "name": src_stat_key, + "query": entry.new_type.query, + "comments": [] + if entry.new_type.comment is None + else [entry.new_type.comment], + } + ) + 
self.set_table_config(entry.name, table) + + def do_quit(self, _arg: str) -> bool: + """Check the updates, save them if desired and quit the configurer.""" + count = 0 + for entry in self.table_entries: + if entry.old_type != entry.new_type: + count += 1 + if entry.old_type is None: + self.print( + "Putting generator {0} on table {1}", + entry.name, + entry.new_type.name, + ) + elif entry.new_type is None: + self.print( + "Deleting generator {1} from table {0}", + entry.name, + entry.old_type.name, + ) + else: + self.print( + "Changing {0} from {1} to {2}", + entry.name, + entry.old_type.name, + entry.new_type.name, + ) + if count == 0: + self.print("You have made no changes.") + reply = self.ask_save() + if reply == "yes": + self._copy_entries() + return True + if reply == "no": + return True + return False + + def do_tables(self, _arg: str) -> None: + """List the tables with their types.""" + for entry in self.table_entries: + old = "-" if entry.old_type is None else entry.old_type.name + new = "-" if entry.new_type is None else entry.new_type.name + desc = new if old == new else f"{old}->{new}" + self.print("{0} {1}", entry.name, desc) + + def do_next(self, arg: str) -> None: + """ + Go to the next table, or a specified table. 
+ + 'next' = go to the next table, 'next tablename' = go to table 'tablename' + """ + if arg: + # Find the index of the table called _arg, if any + index = next( + (i for i, entry in enumerate(self.table_entries) if entry.name == arg), + None, + ) + if index is None: + self.print(self.ERROR_NO_SUCH_TABLE, arg) + return + self._set_table_index(index) + return + self.next_table(self.INFO_NO_MORE_TABLES) + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Get completions for tables and columns.""" + return [ + entry.name for entry in self.table_entries if entry.name.startswith(text) + ] + + def do_previous(self, _arg: str) -> None: + """Go to the previous table.""" + if not self._set_table_index(self.table_index - 1): + self.print(self.ERROR_ALREADY_AT_START) + + def _set_type(self, name: str, query: str, comment: str) -> None: + """Set the current table entry's query.""" + if len(self.table_entries) <= self.table_index: + return + entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] + entry.new_type = MissingnessType( + name=name, + query=query, + comment=comment, + columns=self.get_nonnull_columns(entry.name), + ) + + def _set_none(self) -> None: + """Set the current table to have no missingness applied.""" + if len(self.table_entries) <= self.table_index: + return + self.table_entries[self.table_index].new_type = None + + def do_sampled(self, arg: str) -> None: + """ + Set the current table missingness as 'sampled', and go to the next table. + + 'sampled 3000' means sample 3000 rows at random and choose the + missingness to be the same as one of those 3000 at random. + 'sampled' means the same, but with a default number of rows sampled (1000). + """ + if len(self.table_entries) <= self.table_index: + self.print("Error! 
not on a table") + return + entry = self.table_entries[self.table_index] + if arg == "": + count = 1000 + elif arg.isdecimal(): + count = int(arg) + else: + self.print( + ( + "Error: sampled can be used alone or with" + " an integer argument. {0} is not permitted" + ), + arg, + ) + return + self._set_type( + MissingnessType.SAMPLED, + MissingnessType.sampled_query( + entry.name, + count, + self.get_nonnull_columns(entry.name), + ), + ( + "The missingness patterns and how often they appear in a" + f" sample of {count} from table {entry.name}" + ), + ) + self.print("Table {} set to sampled missingness", self.table_name()) + self.next_table() + + def do_none(self, _arg: str) -> None: + """Set the current table to have no missingness, and go to the next table.""" + self._set_none() + self.print("Table {} set to have no missingness", self.table_name()) + self.next_table() diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py new file mode 100644 index 00000000..c23bfdf2 --- /dev/null +++ b/datafaker/interactive/table.py @@ -0,0 +1,376 @@ +"""Table configuration command shell.""" +from collections.abc import Mapping, MutableMapping, Sequence +from dataclasses import dataclass + +import sqlalchemy +from sqlalchemy import MetaData +from typing import Any, cast + +from datafaker.interactive.base import TableType, DbCmd, TableEntry, TYPE_LETTER, TYPE_PROMPT + +@dataclass +class TableCmdTableEntry(TableEntry): + """Table entry for the table command shell.""" + + old_type: TableType + new_type: TableType + + +class TableCmd(DbCmd): + """Command shell allowing the user to set the type of each table.""" + + intro = ( + "Interactive table configuration (ignore," + " vocabulary, private, generate or empty). Type ? for help.\n" + ) + doc_leader = """Use the commands 'ignore', 'vocabulary', +'private', 'empty' or 'generate' to set the table's type. Use 'next' or +'previous' to change table. Use 'tables' and 'columns' for +information about the database. 
Use 'data', 'peek', 'select' or +'count' to see some data contained in the current table. Use 'quit' +to exit this program.""" + prompt = "(tableconf) " + file = None + WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( + "Vocabulary table {0} references non-vocabulary table {1}" + ) + WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( + "Empty table {1} referenced from non-empty table {0}. {1} will need stories." + ) + WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" + WARNING_TEXT_POTENTIAL_PROBLEMS = ( + "NOTE: The following table types might cause problems later:" + ) + NOTE_TEXT_NO_CHANGES = "You have made no changes." + NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" + + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> TableCmdTableEntry | None: + """ + Make a table entry for the named table. + + :param table_name: The name of the table. + :param table_config: The part of ``config.yaml`` corresponding to this table. + :return: The newly-constructed table entry. 
+ """ + if table_config.get("ignore", False): + return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE) + if table_config.get("vocabulary_table", False): + return TableCmdTableEntry( + table_name, TableType.VOCABULARY, TableType.VOCABULARY + ) + if table_config.get("primary_private", False): + return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE) + if table_config.get("num_rows_per_pass", 1) == 0: + return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY) + return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE) + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ) -> None: + """Initialise a TableCmd.""" + super().__init__(src_dsn, src_schema, metadata, config) + self.set_prompt() + + @property + def table_entries(self) -> list[TableCmdTableEntry]: + """Get the list of table entries.""" + return cast(list[TableCmdTableEntry], self._table_entries) + + def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + """Get the table entry of the table with the given name.""" + entry = super()._find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(TableCmdTableEntry, entry) + + def set_prompt(self) -> None: + """Set the prompt according to the current table and its type.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) + else: + self.prompt = "(table) " + + def set_type(self, t_type: TableType) -> None: + """Set the type of the current table.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + entry.new_type = t_type + + def _copy_entries(self) -> None: + """Alter the configuration to match the new table entries.""" + for entry in self.table_entries: + if entry.old_type != entry.new_type: + table = 
self.get_table_config(entry.name) + if ( + entry.old_type == TableType.EMPTY + and table.get("num_rows_per_pass", 1) == 0 + ): + table["num_rows_per_pass"] = 1 + if entry.new_type == TableType.IGNORE: + table["ignore"] = True + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + elif entry.new_type == TableType.VOCABULARY: + table.pop("ignore", None) + table["vocabulary_table"] = True + table.pop("primary_private", None) + elif entry.new_type == TableType.PRIVATE: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table["primary_private"] = True + elif entry.new_type == TableType.EMPTY: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + table["num_rows_per_pass"] = 0 + else: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + self.set_table_config(entry.name, table) + + def _get_referenced_tables(self, from_table_name: str) -> set[str]: + """Get all the tables referenced by this table's foreign keys.""" + from_meta = self.metadata.tables[from_table_name] + return { + fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys + } + + def _sanity_check_failures(self) -> list[tuple[str, str, str]]: + """Find tables that reference each other that should not given their types.""" + failures = [] + for from_entry in self.table_entries: + from_t = from_entry.new_type + if from_t == TableType.VOCABULARY: + referenced = self._get_referenced_tables(from_entry.name) + for ref in referenced: + to_entry = self._find_entry_by_table_name(ref) + if ( + to_entry is not None + and to_entry.new_type != TableType.VOCABULARY + ): + failures.append( + ( + self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + from_entry.name, + to_entry.name, + ) + ) + return failures + + def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: + """Find tables that reference each other that might cause problems given their types.""" + warnings = [] 
+ for from_entry in self.table_entries: + from_t = from_entry.new_type + if from_t in {TableType.GENERATE, TableType.PRIVATE}: + referenced = self._get_referenced_tables(from_entry.name) + for ref in referenced: + to_entry = self._find_entry_by_table_name(ref) + if to_entry is not None and to_entry.new_type in { + TableType.EMPTY, + TableType.IGNORE, + }: + warnings.append( + ( + self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + from_entry.name, + to_entry.name, + ) + ) + return warnings + + def do_quit(self, _arg: str) -> bool: + """Check the updates, save them if desired and quit the configurer.""" + count = 0 + for entry in self.table_entries: + if entry.old_type != entry.new_type: + count += 1 + self.print( + self.NOTE_TEXT_CHANGING, + entry.name, + entry.old_type.value, + entry.new_type.value, + ) + if count == 0: + self.print(self.NOTE_TEXT_NO_CHANGES) + failures = self._sanity_check_failures() + if failures: + self.print(self.WARNING_TEXT_PROBLEMS_EXIST) + for text, from_t, to_t in failures: + self.print(text, from_t, to_t) + warnings = self._sanity_check_warnings() + if warnings: + self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) + for text, from_t, to_t in warnings: + self.print(text, from_t, to_t) + reply = self.ask_save() + if reply == "yes": + self._copy_entries() + return True + if reply == "no": + return True + return False + + def do_tables(self, _arg: str) -> None: + """List the tables with their types.""" + for entry in self.table_entries: + old = entry.old_type + new = entry.new_type + becomes = " " if old == new else "->" + TYPE_LETTER[new] + self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) + + def do_next(self, arg: str) -> None: + """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" + if arg: + # Find the index of the table called _arg, if any + index = self.find_entry_index_by_table_name(arg) + if index is None: + self.print(self.ERROR_NO_SUCH_TABLE, arg) + return + self._set_table_index(index) + return + 
self.next_table(self.INFO_NO_MORE_TABLES) + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Get the table names that complete the argument to ``next``.""" + return [ + entry.name for entry in self.table_entries if entry.name.startswith(text) + ] + + def do_previous(self, _arg: str) -> None: + """Go to the previous table.""" + if not self._set_table_index(self.table_index - 1): + self.print(self.ERROR_ALREADY_AT_START) + + def do_ignore(self, _arg: str) -> None: + """Set the current table as ignored, and go to the next table.""" + self.set_type(TableType.IGNORE) + self.print("Table {} set as ignored", self.table_name()) + self.next_table() + + def do_vocabulary(self, _arg: str) -> None: + """Set the current table as a vocabulary table, and go to the next table.""" + self.set_type(TableType.VOCABULARY) + self.print("Table {} set to be a vocabulary table", self.table_name()) + self.next_table() + + def do_private(self, _arg: str) -> None: + """Set the current table as a primary private table (such as the table of patients).""" + self.set_type(TableType.PRIVATE) + self.print("Table {} set to be a primary private table", self.table_name()) + self.next_table() + + def do_generate(self, _arg: str) -> None: + """Set the current table as to be generated, and go to the next table.""" + self.set_type(TableType.GENERATE) + self.print("Table {} generate", self.table_name()) + self.next_table() + + def do_empty(self, _arg: str) -> None: + """Set the current table as empty; no generators will be run for it.""" + self.set_type(TableType.EMPTY) + self.print("Table {} empty", self.table_name()) + self.next_table() + + def do_columns(self, _arg: str) -> None: + """Report the column names and metadata.""" + self.report_columns() + + def do_data(self, arg: str) -> None: + """ + Report some data. 
+ + 'data' = report a random ten lines, + 'data 20' = report a random 20 lines, + 'data 20 ColumnName' = report a random twenty entries from ColumnName, + 'data 20 ColumnName 30' = report a random twenty entries from + ColumnName of length at least 30, + """ + args = arg.split() + column = None + number = None + arg_index = 0 + min_length = 0 + table_metadata = self.table_metadata() + if arg_index < len(args) and args[arg_index].isdigit(): + number = int(args[arg_index]) + arg_index += 1 + if arg_index < len(args) and args[arg_index] in table_metadata.columns: + column = args[arg_index] + arg_index += 1 + if arg_index < len(args) and args[arg_index].isdigit(): + min_length = int(args[arg_index]) + arg_index += 1 + if arg_index != len(args): + self.print( + """Did not understand these arguments +The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text. +Type 'columns' to find out valid column names for this table. +Type 'help data' for examples.""" + ) + return + if column is None: + if number is None: + number = 10 + self.print_row_data(number) + else: + if number is None: + number = 48 + self.print_column_data(column, number, min_length) + + def complete_data( + self, text: str, line: str, begidx: int, _endidx: int + ) -> list[str]: + """Get completions for arguments to ``data``.""" + previous_parts = line[: begidx - 1].split() + if len(previous_parts) != 2: + return [] + table_metadata = self.table_metadata() + return [k for k in table_metadata.columns.keys() if k.startswith(text)] + + def print_column_data(self, column: str, count: int, min_length: int) -> None: + """ + Print a sample of data from a certain column of the current table. + + :param column: The name of the column to report on. + :param count: The number of rows to sample. + :param min_length: The minimum length of text to choose from (0 for any text). 
+ """ + where = f"WHERE {column} IS NOT NULL" + if 0 < min_length: + where = f"WHERE LENGTH({column}) >= {min_length}" + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT {column} FROM {self.table_name()}" + f" {where} ORDER BY RANDOM() LIMIT {count}" + ) + ) + self.columnize([str(x[0]) for x in result.all()]) + + def print_row_data(self, count: int) -> None: + """ + Print a sample or rows from the current table. + + :param count: The number of rows to report. + """ + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT * FROM {self.table_name()} ORDER BY RANDOM() LIMIT {count}" + ) + ) + if result is None: + self.print("No rows in this table!") + return + self.print_results(result) diff --git a/datafaker/main.py b/datafaker/main.py index cf7bd3bf..22cf0ef5 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -153,13 +153,10 @@ def create_data( config = read_config_file(config_file) if config_file is not None else {} orm_metadata = load_metadata_for_output(orm_file, config) df_module = import_file(df_file) - table_generator_dict = df_module.table_generator_dict - story_generator_list = df_module.story_generator_list try: row_counts = create_db_data( sorted_non_vocabulary_tables(orm_metadata, config), - table_generator_dict, - story_generator_list, + df_module, num_passes, ) logger.debug( diff --git a/datafaker/utils.py b/datafaker/utils.py index 6d0041de..3cc8c282 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -11,11 +11,11 @@ from types import ModuleType from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union +import psycopg2 import sqlalchemy import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate -from psycopg2.errors import UndefinedObject from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from 
sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.exc import IntegrityError, ProgrammingError @@ -79,7 +79,7 @@ def import_file(file_path: str) -> ModuleType: """ spec = importlib.util.spec_from_file_location("df", file_path) if spec is None or spec.loader is None: - raise Exception(f"No loadable module at {file_path}") + raise ImportError(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module @@ -248,7 +248,7 @@ def emit(self, record: Any) -> None: sys.stdout.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) @@ -275,7 +275,7 @@ def emit(self, record: Any) -> None: sys.stderr.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) @@ -458,7 +458,8 @@ def remove_vocab_foreign_key_constraints( ) except ProgrammingError as e: session.rollback() - if isinstance(e.orig, UndefinedObject): + # pylint: disable=no-member + if isinstance(e.orig, psycopg2.errors.UndefinedObject): logger.debug("Constraint does not exist") else: raise e diff --git a/tests/test_interactive.py b/tests/test_interactive_generators.py similarity index 70% rename from tests/test_interactive.py rename to tests/test_interactive_generators.py index fcb5ced2..26d08aaf 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive_generators.py @@ -1,453 +1,18 @@ -""" Tests for the base module. """ +""" Tests for the configure-generators command. 
""" import copy -import random import re +from collections.abc import MutableMapping from dataclasses import dataclass -from typing import Any, Iterable, MutableMapping +from typing import Any, Iterable from unittest import TestCase from unittest.mock import MagicMock, Mock, patch from sqlalchemy import Connection, MetaData, insert, select from datafaker.generators import NullPartitionedNormalGeneratorFactory -from datafaker.interactive import ( - DbCmd, - GeneratorCmd, - MissingnessCmd, - TableCmd, - update_config_generators, -) -from tests.utils import GeneratesDBTestCase, RequiresDBTestCase - - -class TestDbCmdMixin(DbCmd): - """A mixin for capturing output from interactive commands.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize a TestDbCmdMixin""" - super().__init__(*args, **kwargs) - self.reset() - - def reset(self) -> None: - """Reset all the debug messages collected so far.""" - self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] - self.headings: list[str] = [] - self.rows: list[list[str]] = [] - self.column_items: list[list[str]] = [] - self.columns: dict[str, list[Any]] = {} - - def print(self, text: str, *args: Any, **kwargs: Any) -> None: - """Capture the printed message.""" - self.messages.append((text, args, kwargs)) - - def print_table(self, headings: list[str], rows: list[list[str]]) -> None: - """Capture the printed table.""" - self.headings = headings - self.rows = rows - - def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: - """Capture the printed table.""" - self.columns = columns - - # pylint: disable=arguments-renamed - def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: - """Capture the printed table.""" - if items is not None: - self.column_items.append(items) - - def ask_save(self) -> str: - """Quitting always works without needing to ask the user.""" - return "yes" - - -class TestTableCmd(TableCmd, TestDbCmdMixin): - """TableCmd but mocked""" 
- - -class ConfigureTablesTests(RequiresDBTestCase): - """Testing configure-tables.""" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: - return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) - - -class ConfigureTablesSrcTests(ConfigureTablesTests): - """Testing configure-tables with src.dump.""" - - dump_file_path = "src.dump" - database_name = "src" - schema_name = "public" - - def test_table_name_prompts(self) -> None: - """Test that the prompts follow the names of the tables.""" - config: MutableMapping[str, Any] = {} - with self._get_cmd(config) as tc: - table_names = list(self.metadata.tables.keys()) - for t in table_names: - self.assertIn(t, tc.prompt) - tc.do_next("") - self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) - tc.reset() - for t in reversed(table_names): - self.assertIn(t, tc.prompt) - tc.do_previous("") - self.assertListEqual( - tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] - ) - tc.reset() - bad_table_name = "notarealtable" - tc.do_next(bad_table_name) - self.assertListEqual( - tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] - ) - tc.reset() - good_table_name = table_names[2] - tc.do_next(good_table_name) - self.assertListEqual(tc.messages, []) - self.assertIn(good_table_name, tc.prompt) - - def test_column_display(self) -> None: - """Test that we can see the names of the columns.""" - config: MutableMapping[str, Any] = {} - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_columns("") - self.assertListEqual( - tc.rows, - [ - ["id", "INTEGER", True, False, ""], - ["a", "BOOLEAN", False, False, ""], - ["b", "BOOLEAN", False, False, ""], - ["c", "TEXT", False, False, ""], - ], - ) - - def test_null_configuration(self) -> None: - """A table still works if its configuration is None.""" - config = { - "tables": None, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - 
tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue( - tables["unique_constraint_test"].get("primary_private", False) - ) - - def test_null_table_configuration(self) -> None: - """A table still works if its configuration is None.""" - config = { - "tables": { - "unique_constraint_test": None, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue( - tables["unique_constraint_test"].get("primary_private", False) - ) - - def test_configure_tables(self) -> None: - """Test that we can change columns to ignore, vocab or generate.""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - "hospital_visit": { - "num_passes": 0, - }, - "empty_vocabulary": { - "private": True, - }, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_generate("") - tc.do_next("person") - tc.do_vocabulary("") - tc.do_next("mitigation_type") - tc.do_ignore("") - tc.do_next("hospital_visit") - tc.do_private("") - tc.do_quit("") - tc.do_next("empty_vocabulary") - tc.do_empty("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertFalse( - tables["unique_constraint_test"].get("primary_private", False) - ) - self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) - self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) - 
self.assertTrue(tables["no_pk_test"].get("ignore", False)) - self.assertFalse(tables["no_pk_test"].get("primary_private", False)) - self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) - self.assertTrue(tables["person"].get("vocabulary_table", False)) - self.assertFalse(tables["person"].get("ignore", False)) - self.assertFalse(tables["person"].get("primary_private", False)) - self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) - self.assertTrue(tables["mitigation_type"].get("ignore", False)) - self.assertFalse(tables["mitigation_type"].get("primary_private", False)) - self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) - self.assertFalse(tables["hospital_visit"].get("ignore", False)) - self.assertTrue(tables["hospital_visit"].get("primary_private", False)) - self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) - self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) - self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) - self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) - - def test_print_data(self) -> None: - """Test that we can print random rows from the table and random data from columns.""" - person_table = self.metadata.tables["person"] - with self.sync_engine.connect() as conn: - person_rows = conn.execute(select(person_table)).mappings().fetchall() - person_data = {row["person_id"]: row for row in person_rows} - name_set = {row["name"] for row in person_rows} - person_headings = ["person_id", "name", "research_opt_out", "stored_from"] - with self._get_cmd({}) as tc: - tc.do_next("person") - tc.do_data("") - self.assertListEqual(tc.headings, person_headings) - self.assertEqual(len(tc.rows), 10) # 
default number of rows is 10 - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - rows_to_get_count = 6 - tc.do_data(str(rows_to_get_count)) - self.assertListEqual(tc.headings, person_headings) - self.assertEqual(len(tc.rows), rows_to_get_count) - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - to_get_count = 12 - tc.do_data(f"{to_get_count} name") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - self.assertLessEqual(set(tc.column_items[0]), name_set) - tc.reset() - tc.do_data(f"{to_get_count} name 12") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - tc.reset() - tc.do_data(f"{to_get_count} name 13") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual( - set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) - ) - tc.reset() - tc.do_data(f"{to_get_count} name 16") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual( - set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) - ) - - def test_list_tables(self) -> None: - """Test that we can list the tables""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_ignore("") - tc.do_next("person") - tc.do_vocabulary("") - tc.reset() - tc.do_tables("") - person_listed = False - unique_constraint_test_listed = False - no_pk_test_listed = False - for _text, args, _kwargs in tc.messages: - if args[2] == "person": - self.assertFalse(person_listed) - person_listed = True - self.assertEqual(args[0], "G") - self.assertEqual(args[1], "->V") - elif args[2] == "unique_constraint_test": - self.assertFalse(unique_constraint_test_listed) 
- unique_constraint_test_listed = True - self.assertEqual(args[0], "V") - self.assertEqual(args[1], "->I") - elif args[2] == "no_pk_test": - self.assertFalse(no_pk_test_listed) - no_pk_test_listed = True - self.assertEqual(args[0], "I") - self.assertEqual(args[1], " ") - else: - self.assertEqual(args[0], "G") - self.assertEqual(args[1], " ") - self.assertTrue(person_listed) - self.assertTrue(unique_constraint_test_listed) - self.assertTrue(no_pk_test_listed) - - -class ConfigureTablesInstrumentsTests(ConfigureTablesTests): - """Testing configure-tables with the instrument.sql database.""" - - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def test_sanity_checks_both(self) -> None: - """ - Test ``configure-tables`` sanity checks. - """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with self._get_cmd(config) as tc: - tc.reset() - tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - ("model", "manufacturer"), - {}, - ), - ) - self.assertEqual( - tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) - ) - self.assertEqual( - tc.messages[4], - ( - TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - ("signature_model", "player"), - {}, - ), - ) - - def test_sanity_checks_warnings_only(self) -> None: - """ - Test ``configure-tables`` sanity checks. 
- """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("manufacturer") - tc.do_vocabulary("") - tc.reset() - tc.do_quit("") - self.assertEqual( - tc.messages[0], - ( - TableCmd.NOTE_TEXT_CHANGING, - ("manufacturer", "ignore", "vocabulary"), - {}, - ), - ) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - ("signature_model", "player"), - {}, - ), - ) - - def test_sanity_checks_errors_only(self) -> None: - """ - Test ``configure-tables`` sanity checks. - """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("signature_model") - tc.do_empty("") - tc.reset() - tc.do_quit("") - self.assertEqual( - tc.messages[0], - ( - TableCmd.NOTE_TEXT_CHANGING, - ("signature_model", "generate", "empty"), - {}, - ), - ) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - ("model", "manufacturer"), - {}, - ), - ) +from datafaker.interactive import update_config_generators +from datafaker.interactive.generators import GeneratorCmd +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): @@ -1169,97 +734,6 @@ def test_create_with_weighted_choice(self) -> None: self.assertSetEqual(threes, {1, 2, 3, 4, 5}) -class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): - """MissingnessCmd but mocked""" - - -class 
ConfigureMissingnessTests(RequiresDBTestCase): - """Testing configure-missing.""" - - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: - """We are using configure-missingness.""" - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_set_missingness_to_sampled(self) -> None: - """Test that we can set one table to sampled missingness.""" - with self._get_cmd({}) as mc: - table = "signature_model" - mc.do_next(table) - mc.do_counts("") - self.assertListEqual( - mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] - ) - # Check the counts of NULLs in each column - self.assertListEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) - mc.do_sampled("") - mc.do_quit("") - self.assertListEqual( - mc.config["tables"][table]["missingness_generators"], - [ - { - "columns": ["player_id", "based_on"], - "kwargs": { - "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' - }, - "name": "column_presence.sampled", - } - ], - ) - self.assertEqual( - mc.config["src-stats"][0]["name"], - "missing_auto__signature_model__0", - ) - self.assertEqual( - mc.config["src-stats"][0]["query"], - ( - "SELECT COUNT(*) AS row_count," - " player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null," - " based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000)" - " AS __t GROUP BY player_id__is_null, based_on__is_null" - ), - ) - - -class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): - """Testing configure-missing with generation.""" - - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_missingness(self) -> None: - 
"""Test that we can sample real missingness and reproduce it.""" - random.seed(45) - # Configure the missingness - table_name = "signature_model" - with self._get_cmd({}) as mc: - mc.do_next(table_name) - mc.do_sampled("") - mc.do_quit("") - config = mc.config - self.generate_data(config, num_passes=100) - # Test that each missingness pattern is present in the database - with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).mappings().fetchall() - patterns: set[int] = set() - for row in rows: - p = 0 if row["player_id"] is None else 1 - b = 0 if row["based_on"] is None else 2 - patterns.add(p + b) - # all pattern possibilities should be present - self.assertSetEqual(patterns, {0, 1, 2, 3}) - - class GeneratorTests(GeneratesDBTestCase): """Testing configure-generators with generation.""" @@ -1445,6 +919,7 @@ def covar(self) -> float: return (self.xy - self.x * self.y / self.n) / (self.n - 1) +# pylint: disable=too-many-instance-attributes class EavMeasurementTableStats: """The statistics for the Measurement table of eav.sql.""" diff --git a/tests/test_interactive_missingness.py b/tests/test_interactive_missingness.py new file mode 100644 index 00000000..7a63ea52 --- /dev/null +++ b/tests/test_interactive_missingness.py @@ -0,0 +1,100 @@ +""" Tests for the configure-missingness command. 
""" +import random +from collections.abc import MutableMapping +from typing import Any + +from sqlalchemy import select + +from datafaker.interactive import MissingnessCmd +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin + + +class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): + """MissingnessCmd but mocked""" + + +class ConfigureMissingnessTests(RequiresDBTestCase): + """Testing configure-missing.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + """We are using configure-missingness.""" + return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + + def test_set_missingness_to_sampled(self) -> None: + """Test that we can set one table to sampled missingness.""" + with self._get_cmd({}) as mc: + table = "signature_model" + mc.do_next(table) + mc.do_counts("") + self.assertSequenceEqual( + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] + ) + # Check the counts of NULLs in each column + self.assertSequenceEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) + mc.do_sampled("") + mc.do_quit("") + self.assertListEqual( + mc.config["tables"][table]["missingness_generators"], + [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' + }, + "name": "column_presence.sampled", + } + ], + ) + self.assertEqual( + mc.config["src-stats"][0]["name"], + "missing_auto__signature_model__0", + ) + self.assertEqual( + mc.config["src-stats"][0]["query"], + ( + "SELECT COUNT(*) AS row_count," + " player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null," + " based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000)" + " AS __t GROUP BY player_id__is_null, based_on__is_null" + ), + ) + + +class 
ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): + """Testing configure-missing with generation.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + + def test_create_with_missingness(self) -> None: + """Test that we can sample real missingness and reproduce it.""" + random.seed(45) + # Configure the missingness + table_name = "signature_model" + with self._get_cmd({}) as mc: + mc.do_next(table_name) + mc.do_sampled("") + mc.do_quit("") + config = mc.config + self.generate_data(config, num_passes=100) + # Test that each missingness pattern is present in the database + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).mappings().fetchall() + patterns: set[int] = set() + for row in rows: + p = 0 if row["player_id"] is None else 1 + b = 0 if row["based_on"] is None else 2 + patterns.add(p + b) + # all pattern possibilities should be present + self.assertSetEqual(patterns, {0, 1, 2, 3}) diff --git a/tests/test_interactive_table.py b/tests/test_interactive_table.py new file mode 100644 index 00000000..04b157e7 --- /dev/null +++ b/tests/test_interactive_table.py @@ -0,0 +1,398 @@ +""" Tests for the configure-tables command. 
""" +from collections.abc import MutableMapping +from typing import Any + +from sqlalchemy import select + +from datafaker.interactive import TableCmd +from tests.utils import RequiresDBTestCase, TestDbCmdMixin + + +class TestTableCmd(TableCmd, TestDbCmdMixin): + """TableCmd but mocked""" + + +class ConfigureTablesTests(RequiresDBTestCase): + """Testing configure-tables.""" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: + return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) + + +class ConfigureTablesSrcTests(ConfigureTablesTests): + """Testing configure-tables with src.dump.""" + + dump_file_path = "src.dump" + database_name = "src" + schema_name = "public" + + def test_table_name_prompts(self) -> None: + """Test that the prompts follow the names of the tables.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + table_names = list(self.metadata.tables.keys()) + for t in table_names: + self.assertIn(t, tc.prompt) + tc.do_next("") + self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) + tc.reset() + for t in reversed(table_names): + self.assertIn(t, tc.prompt) + tc.do_previous("") + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] + ) + tc.reset() + bad_table_name = "notarealtable" + tc.do_next(bad_table_name) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] + ) + tc.reset() + good_table_name = table_names[2] + tc.do_next(good_table_name) + self.assertSequenceEqual(tc.messages, []) + self.assertIn(good_table_name, tc.prompt) + + def test_column_display(self) -> None: + """Test that we can see the names of the columns.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_columns("") + self.assertSequenceEqual( + tc.rows, + [ + ["id", "INTEGER", True, False, ""], + ["a", "BOOLEAN", False, False, ""], + ["b", "BOOLEAN", False, 
False, ""], + ["c", "TEXT", False, False, ""], + ], + ) + + def test_null_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": None, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_null_table_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": { + "unique_constraint_test": None, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_configure_tables(self) -> None: + """Test that we can change columns to ignore, vocab or generate.""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + "hospital_visit": { + "num_passes": 0, + }, + "empty_vocabulary": { + "private": True, + }, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_generate("") + tc.do_next("person") + tc.do_vocabulary("") + tc.do_next("mitigation_type") + tc.do_ignore("") + tc.do_next("hospital_visit") + tc.do_private("") + tc.do_quit("") + tc.do_next("empty_vocabulary") + tc.do_empty("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + 
self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertFalse( + tables["unique_constraint_test"].get("primary_private", False) + ) + self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) + self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) + self.assertTrue(tables["no_pk_test"].get("ignore", False)) + self.assertFalse(tables["no_pk_test"].get("primary_private", False)) + self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) + self.assertTrue(tables["person"].get("vocabulary_table", False)) + self.assertFalse(tables["person"].get("ignore", False)) + self.assertFalse(tables["person"].get("primary_private", False)) + self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) + self.assertTrue(tables["mitigation_type"].get("ignore", False)) + self.assertFalse(tables["mitigation_type"].get("primary_private", False)) + self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) + self.assertFalse(tables["hospital_visit"].get("ignore", False)) + self.assertTrue(tables["hospital_visit"].get("primary_private", False)) + self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) + self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) + self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) + self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) + + def test_print_data(self) -> None: + """Test that we can print random rows from the table and random data from columns.""" + person_table = self.metadata.tables["person"] + with self.sync_engine.connect() as conn: + person_rows = conn.execute(select(person_table)).mappings().fetchall() + person_data = {row["person_id"]: row for 
row in person_rows} + name_set = {row["name"] for row in person_rows} + person_headings = ["person_id", "name", "research_opt_out", "stored_from"] + with self._get_cmd({}) as tc: + tc.do_next("person") + tc.do_data("") + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), 10) # default number of rows is 10 + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + rows_to_get_count = 6 + tc.do_data(str(rows_to_get_count)) + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), rows_to_get_count) + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + to_get_count = 12 + tc.do_data(f"{to_get_count} name") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual(len(tc.column_items[0]), to_get_count) + self.assertLessEqual(set(tc.column_items[0]), name_set) + tc.reset() + tc.do_data(f"{to_get_count} name 12") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual(len(tc.column_items[0]), to_get_count) + tc.reset() + tc.do_data(f"{to_get_count} name 13") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) + ) + tc.reset() + tc.do_data(f"{to_get_count} name 16") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) + ) + + def test_list_tables(self) -> None: + """Test that we can list the tables""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_ignore("") + tc.do_next("person") + tc.do_vocabulary("") + tc.reset() + tc.do_tables("") + person_listed = False + 
unique_constraint_test_listed = False + no_pk_test_listed = False + for _text, args, _kwargs in tc.messages: + if args[2] == "person": + self.assertFalse(person_listed) + person_listed = True + self.assertEqual(args[0], "G") + self.assertEqual(args[1], "->V") + elif args[2] == "unique_constraint_test": + self.assertFalse(unique_constraint_test_listed) + unique_constraint_test_listed = True + self.assertEqual(args[0], "V") + self.assertEqual(args[1], "->I") + elif args[2] == "no_pk_test": + self.assertFalse(no_pk_test_listed) + no_pk_test_listed = True + self.assertEqual(args[0], "I") + self.assertEqual(args[1], " ") + else: + self.assertEqual(args[0], "G") + self.assertEqual(args[1], " ") + self.assertTrue(person_listed) + self.assertTrue(unique_constraint_test_listed) + self.assertTrue(no_pk_test_listed) + + +class ConfigureTablesInstrumentsTests(ConfigureTablesTests): + """Testing configure-tables with the instrument.sql database.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def test_sanity_checks_both(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
+ """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with self._get_cmd(config) as tc: + tc.reset() + tc.do_quit("") + self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) + self.assertEqual( + tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[4], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_warnings_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("manufacturer") + tc.do_vocabulary("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("manufacturer", "ignore", "vocabulary"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_errors_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
+ """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("signature_model") + tc.do_empty("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("signature_model", "generate", "empty"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) diff --git a/tests/test_main.py b/tests/test_main.py index e318f1e5..1d5f59d8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -398,9 +398,7 @@ def test_make_stats( self.assertSuccess(result) with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - mock_make.assert_called_once_with( - get_test_settings().src_dsn, config, None - ) + mock_make.assert_called_once_with(get_test_settings().src_dsn, config, None) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" ) diff --git a/tests/utils.py b/tests/utils.py index 78df87cc..ab6f1d23 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import shutil import traceback +from collections.abc import MutableSequence, Sequence from functools import lru_cache from pathlib import Path from subprocess import run @@ -16,6 +17,7 @@ from datafaker import settings from datafaker.create import create_db_data_into +from datafaker.interactive.base import DbCmd from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from from datafaker.utils import ( @@ -264,12 +266,9 @@ def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with 
all this stuff datafaker_module = import_file(self.generators_file_path) - table_generator_dict = datafaker_module.table_generator_dict - story_generator_list = datafaker_module.story_generator_list create_db_data_into( sorted_non_vocabulary_tables(self.metadata, config), - table_generator_dict, - story_generator_list, + datafaker_module, num_passes, self.dsn, self.schema_name, @@ -288,3 +287,45 @@ def generate_data( self.remove_data(config) self.create_data(config, num_passes) return src_stats + + +class TestDbCmdMixin(DbCmd): + """A mixin for capturing output from interactive commands.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a TestDbCmdMixin""" + super().__init__(*args, **kwargs) + self.reset() + + def reset(self) -> None: + """Reset all the debug messages collected so far.""" + self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] + self.headings: Sequence[str] = [] + self.rows: Sequence[Sequence[str]] = [] + self.column_items: MutableSequence[Sequence[str]] = [] + self.columns: Mapping[str, Sequence[Any]] = {} + + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Capture the printed message.""" + self.messages.append((text, args, kwargs)) + + def print_table( + self, headings: Sequence[str], rows: Sequence[Sequence[str]] + ) -> None: + """Capture the printed table.""" + self.headings = headings + self.rows = rows + + def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: + """Capture the printed table.""" + self.columns = columns + + # pylint: disable=arguments-renamed + def columnize(self, items: Sequence[str] | None, _displaywidth: int = 80) -> None: + """Capture the printed table.""" + if items is not None: + self.column_items.append(items) + + def ask_save(self) -> str: + """Quitting always works without needing to ask the user.""" + return "yes" From 2a4982f9ba9dca563e69bdb9a4f836c1c996d743 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:15:12 
+0100 Subject: [PATCH 21/44] Pre-commit cleaned. --- datafaker/generators/base.py | 7 +- datafaker/generators/choice.py | 1 + datafaker/generators/continuous.py | 106 ++- datafaker/generators/mimesis.py | 11 +- datafaker/generators/partitioned.py | 289 +++---- datafaker/interactive/__init__.py | 6 +- datafaker/interactive/base.py | 26 +- datafaker/interactive/generators.py | 90 +-- datafaker/interactive/table.py | 2 +- datafaker/make.py | 22 +- datafaker/utils.py | 60 +- tests/test_interactive_generators.py | 708 ++---------------- ...test_interactive_generators_partitioned.py | 419 +++++++++++ tests/test_noninteractive_generators.py | 179 +++++ 14 files changed, 1025 insertions(+), 901 deletions(-) create mode 100644 tests/test_interactive_generators_partitioned.py create mode 100644 tests/test_noninteractive_generators.py diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py index 1adcb9ca..f2a1459f 100644 --- a/datafaker/generators/base.py +++ b/datafaker/generators/base.py @@ -13,7 +13,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import T, logger +from datafaker.utils import logger NumericType = Union[int, float] @@ -22,6 +22,10 @@ generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) +class GeneratorError(Exception): + """Error thrown from Datafaker Generators.""" + + class Generator(ABC): """ Random data generator. @@ -264,6 +268,7 @@ class Buckets: the fit of generators against it. 
""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, engine: Engine, diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py index 140b6860..54f69d3e 100644 --- a/datafaker/generators/choice.py +++ b/datafaker/generators/choice.py @@ -49,6 +49,7 @@ class ChoiceGenerator(Generator): STORE_COUNTS = False + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, table_name: str, diff --git a/datafaker/generators/continuous.py b/datafaker/generators/continuous.py index 42e4bcd4..fc50c7fe 100644 --- a/datafaker/generators/continuous.py +++ b/datafaker/generators/continuous.py @@ -1,8 +1,11 @@ """Generator factories for making generators of continuous distributions.""" -from typing import Any, Sequence +import itertools +from collections.abc import Iterable, Sequence +from typing import Any from sqlalchemy import Column, Engine, RowMapping, text +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.types import Integer, Numeric from datafaker.generators.base import ( @@ -13,7 +16,7 @@ dist_gen, get_column_type, ) -from datafaker.utils import logger +from datafaker.utils import Empty, logger class ContinuousDistributionGenerator(Generator): @@ -166,20 +169,26 @@ def get_generators( class LogNormalGenerator(Generator): """Generator producing numbers in a log-normal distribution.""" - # TODO: figure out the real buckets here (this was from a random sample in R) + # R: + # > xs<-seq(-2,2,0.5)*sqrt((exp(1)-1)*exp(1))+exp(0.5) + # > ys <- plnorm(xs) + # > c(ys, 1) - c(0,ys) + # [1] 0.00000000 0.00000000 0.00000000 0.28589471 0.40556775 0.15086088 + # [7] 0.06716451 0.03428958 0.01924848 0.03697409 expected_buckets = [ 0, 0, 0, - 0.28627, - 0.40607, - 0.14937, - 0.06735, - 0.03492, - 0.01918, - 0.03684, + 0.28589471, + 0.40556775, + 0.15086088, + 0.06716451, + 0.03428958, + 0.01924848, + 0.03697409, ] + # pylint: disable=too-many-arguments too-many-positional-arguments def 
__init__( self, table_name: str, @@ -290,6 +299,7 @@ def _get_generators_from_buckets( class MultivariateNormalGenerator(Generator): """Generator of multiple values drawn from a multivariate normal distribution.""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, table_name: str, @@ -359,11 +369,12 @@ def query_var(self, column: str) -> str: """Get the SQL expression of the value to query for this column.""" return column + # pylint: disable=too-many-arguments too-many-positional-arguments def query( self, table: str, - columns: list[Column], - predicates: list[str] = [], + columns: Sequence[Column], + predicates: Iterable[str] = Empty.iterable(), group_by_clause: str = "", constant_clauses: str = "", constants: str = "", @@ -385,17 +396,6 @@ def query( :param suppress_count: a group smaller than this will be suppressed. :param sample_count: this many samples will be taken from each partition. """ - preds = [self.query_predicate(col) for col in columns] + predicates - where = " WHERE " + " AND ".join(preds) if preds else "" - avgs = "".join( - f", AVG({self.query_var(col.name)}) AS m{i}" - for i, col in enumerate(columns) - ) - multiples = "".join( - f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" - for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[: iy + 1]) - ) means = "".join(f", _q.m{i}" for i in range(len(columns))) covs = "".join( ( @@ -405,20 +405,58 @@ def query( for iy in range(len(columns)) for ix in range(iy + 1) ) - if sample_count is None: - subquery = table + where - else: - subquery = ( - f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" - f" LIMIT {sample_count}) AS _sampled" - ) - # if there are any numeric columns we need at least# + subquery = self._inner_query(table, columns, predicates, sample_count) + # if there are any numeric columns we need at least # two rows to make any (co)variances at all suppress_clause = f" WHERE {suppress_count} < 
_q.count" if columns else "" return ( f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" - f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" - f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" + f" FROM ({self._middle_query(columns, constants, subquery, group_by_clause)})" + f" AS _q{suppress_clause}" + ) + + def _inner_query( + self, + table: str, + columns: Sequence[Column], + predicates: Iterable[str], + sample_count: int | None, + ) -> str: + """Get the rows from the table that we are interested in.""" + preds = itertools.chain( + (self.query_predicate(col) for col in columns), + predicates, + ) + where = " AND ".join(preds) if preds else "" + if where: + where = " WHERE " + where + if sample_count is None: + return table + where + return ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) + + def _middle_query( + self, + columns: Sequence[Column], + constants: str, + inner_query: str, + group_by_clause: str, + ) -> str: + """Get the basic statistics (and constants) from the inner query.""" + multiples = "".join( + f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" + for iy, coly in enumerate(columns) + for ix, colx in enumerate(columns[: iy + 1]) + ) + avgs = "".join( + f", AVG({self.query_var(col.name)}) AS m{i}" + for i, col in enumerate(columns) + ) + return ( + f"SELECT COUNT(*) AS count{multiples}{avgs}{constants}" + f" FROM {inner_query}{group_by_clause}" ) def get_generators( @@ -439,7 +477,7 @@ def get_generators( with engine.connect() as connection: try: covariates = connection.execute(text(query)).mappings().first() - except Exception as e: + except SQLAlchemyError as e: logger.debug("SQL query %s failed with error %s", query, e) return [] if not covariates or covariates["c0_0"] is None: diff --git a/datafaker/generators/mimesis.py b/datafaker/generators/mimesis.py index 65c5d98d..b3003359 100644 --- 
a/datafaker/generators/mimesis.py +++ b/datafaker/generators/mimesis.py @@ -5,12 +5,14 @@ import mimesis import mimesis.locales from sqlalchemy import Column, Engine, text +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from datafaker.base import DistributionGenerator from datafaker.generators.base import ( Buckets, + DistributionGenerator, Generator, + GeneratorError, GeneratorFactory, get_column_type, ) @@ -41,12 +43,12 @@ def __init__( f = generic for part in function_name.split("."): if not hasattr(f, part): - raise Exception( + raise GeneratorError( f"Mimesis does not have a function {function_name}: {part} not found" ) f = getattr(f, part) if not callable(f): - raise Exception( + raise GeneratorError( f"Mimesis object {function_name} is not a callable," " so cannot be used as a generator" ) @@ -152,6 +154,7 @@ def generate_data(self, count: int) -> list[Any]: class MimesisDateTimeGenerator(MimesisGeneratorBase): """DateTime generator using Mimesis.""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, column: Column, @@ -306,7 +309,7 @@ def get_generators( f"LENGTH({column.name})", ) fitness_fn = len - except Exception: + except SQLAlchemyError: # Some column types that appear to be strings (such as enums) # cannot have their lengths measured. In this case we cannot # detect fitness using lengths. 
diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py index 93a7880f..f14af736 100644 --- a/datafaker/generators/partitioned.py +++ b/datafaker/generators/partitioned.py @@ -80,6 +80,43 @@ def comment(self) -> str: ) +@dataclass +class NullableColumn: + """A reference to a nullable column whose nullability is part of a partitioning.""" + + column: Column + # The bit (power of two) of the number of the partition in the partition sizes list + bitmask: int + + +class PartitionCountQuery: + """Query, result and comment for the row counts of the null pattern partitions.""" + + def __init__( + self, + connection: Connection, + query: str, + table_name: str, + nullable_columns: Iterable[NullableColumn], + ) -> None: + """ + Initialise the partition count query. + + :param connection: Database connection. + :param query: The query getting the row counts of the null pattern partitions. + :param table_name: The name of the table being queried. + :param nullable_columns: The columns that are being checked for nullness. + """ + self.query = query + rows = connection.execute(text(query)).mappings().fetchall() + self.results = [dict(row) for row in rows] + self.comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table_name} being null" + ) + + class NullPartitionedNormalGenerator(Generator): """ A generator of mixed numeric and non-numeric data. @@ -97,23 +134,20 @@ class NullPartitionedNormalGenerator(Generator): rows are used because no covariate matrix is required for this). 
""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, query_name: str, partitions: dict[int, RowPartition], function_name: str = "grouped_multivariate_lognormal", name_suffix: str | None = None, - partition_count_query: str | None = None, - partition_counts: Iterable[RowMapping] = [], - partition_count_comment: str | None = None, + partition_count_query: PartitionCountQuery | None = None, ): """Initialise a NullPartitionedNormalGenerator.""" self._query_name = query_name self._partitions = partitions self._function_name = function_name self._partition_count_query = partition_count_query - self._partition_counts = [dict(pc) for pc in partition_counts] - self._partition_count_comment = partition_count_comment if name_suffix: self._name = f"null-partitioned {function_name} [{name_suffix}]" else: @@ -186,8 +220,8 @@ def custom_queries(self) -> dict[str, Any]: return partitions return { self._count_query_name(): { - "comment": self._partition_count_comment, - "query": self._partition_count_query, + "comment": self._partition_count_query.comment, + "query": self._partition_count_query.query, }, **partitions, } @@ -223,12 +257,16 @@ def _actual_kwargs_with_combinations( def actual_kwargs(self) -> dict[str, Any]: """Get the kwargs (summary statistics) this generator was instantiated with.""" + if self._partition_count_query is None: + counts = None + else: + counts = self._partition_count_query.results return { "alternative_configs": [ self._actual_kwargs_with_combinations(self._partitions[index]) for index in range(len(self._partitions)) ], - "counts": self._partition_counts, + "counts": counts, } def generate_data(self, count: int) -> list[Any]: @@ -252,15 +290,7 @@ def powerset(xs: list[T]) -> Iterable[Iterable[T]]: return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) -@dataclass -class NullableColumn: - """A reference to a nullable column whose nullability is part of a partitioning.""" - - column: Column - # 
The bit (power of two) of the number of the partition in the partition sizes list - bitmask: int - - +# pylint: disable=too-many-instance-attributes class NullPatternPartition: """Get the definition of a partition (in other words, what makes it not another partition).""" @@ -362,6 +392,70 @@ def get_partition_count_query( f' FROM {table} GROUP BY "index") AS _q {where}' ) + def _get_row_partition( + self, + table: str, + partition: NullPatternPartition, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> RowPartition: + """Get the RowPartition from a NullPatternPartition.""" + query = self.query( + table=table, + columns=partition.included_numeric, + predicates=partition.predicates, + group_by_clause=partition.group_by_clause, + constants=partition.constants, + constant_clauses=partition.constant_clauses, + suppress_count=suppress_count, + sample_count=sample_count, + ) + return RowPartition( + query, + partition.included_numeric, + partition.included_choice, + partition.excluded, + partition.nones, + [], + ) + + # pylint: disable=too-many-arguments too-many-positional-arguments + def _get_generator( + self, + connection: Connection, + table_name: str, + columns: list[Column], + nullable_columns: list[NullableColumn], + where: str | None = None, + name_suffix: str | None = None, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> NullPartitionedNormalGenerator | None: + query = self.get_partition_count_query(nullable_columns, table_name, where) + partitions: dict[int, RowPartition] = {} + for partition_nonnulls in powerset(nullable_columns): + partition_def = NullPatternPartition(columns, partition_nonnulls) + partitions[partition_def.index] = self._get_row_partition( + table_name, + partition_def, + suppress_count=suppress_count, + sample_count=sample_count, + ) + if not self._execute_partition_queries(connection, partitions): + return None + return NullPartitionedNormalGenerator( + f"{table_name}__{columns[0].name}", + 
partitions, + self.function_name(), + name_suffix=name_suffix, + partition_count_query=PartitionCountQuery( + connection, + query, + table_name, + nullable_columns, + ), + ) + def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: @@ -372,139 +466,54 @@ def get_generators( if not nullable_columns: return [] table = columns[0].table.name - query_name = f"{table}__{columns[0].name}" - # Partitions for minimal suppression and no sampling - row_partitions_maximal: dict[int, RowPartition] = {} - # Partitions for minimal suppression but sampling - row_partitions_sampled: dict[int, RowPartition] = {} - # Partitions for normal suppression and severe sampling - row_partitions_ss: dict[int, RowPartition] = {} - for partition_nonnulls in powerset(nullable_columns): - partition_def = NullPatternPartition(columns, partition_nonnulls) - query_all = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - ) - row_partitions_maximal[partition_def.index] = RowPartition( - query_all, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - query_sampled = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_sampled[partition_def.index] = RowPartition( - query_sampled, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - {}, - ) - query_ss = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - 
group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - suppress_count=self.SUPPRESS_COUNT, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_ss[partition_def.index] = RowPartition( - query_ss, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - gens: list[Generator] = [] + gens: list[Generator | None] = [] try: with engine.connect() as connection: - partition_query_max = self.get_partition_count_query( - nullable_columns, table - ) - partition_count_max_results = ( - connection.execute(text(partition_query_max)).mappings().fetchall() - ) - count_comment = ( - "Number of rows for each combination of the columns" - f" { {nc.column.name for nc in nullable_columns} }" - f" of the table {table} being null" - ) - if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) - ) - if self._execute_partition_queries(connection, row_partitions_sampled): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_sampled, - self.function_name(), - name_suffix="sampled", - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, ) - if self._execute_partition_queries(connection, row_partitions_sampled): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_sampled, - self.function_name(), - name_suffix="sampled", - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - 
) + ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + name_suffix="sampled", + sample_count=self.SAMPLE_COUNT, ) - partition_query_ss = self.get_partition_count_query( - nullable_columns, - table, - where=f"WHERE {self.SUPPRESS_COUNT} < count", ) - partition_count_ss_results = ( - connection.execute(text(partition_query_ss)).mappings().fetchall() + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + where=f"WHERE {self.SUPPRESS_COUNT} < count", + name_suffix="sampled and suppressed", + suppress_count=self.SUPPRESS_COUNT, + sample_count=self.SAMPLE_COUNT, + ) ) - if self._execute_partition_queries(connection, row_partitions_ss): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + where=f"WHERE {self.SUPPRESS_COUNT} < count", + name_suffix="suppressed", + suppress_count=self.SUPPRESS_COUNT, ) + ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) return [] - return gens + return [gen for gen in gens if gen] def _execute_partition_queries( self, diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py index 952eadf4..c279720f 100644 --- a/datafaker/interactive/__init__.py +++ b/datafaker/interactive/__init__.py @@ -20,7 +20,7 @@ if not hasattr(readline, "backend"): setattr(readline, "backend", "readline") -except: +except ImportError: pass @@ -86,7 +86,7 @@ def update_config_generators( if line: if len(line) < 3: logger.error( - "line {0} of file {1} has fewer than three values", + "line %d of file %s has fewer than three values", line_no, spec_path, ) @@ -95,6 +95,6 @@ def 
update_config_generators( if len(cols) == 1 or gc.set_merged_columns(cols[0], cols[1]): try_setting_generator(gc, itertools.islice(line, 2, None)) else: - logger.warning("no such column {0}[{1}]", line[0], line[1]) + logger.warning("no such column %s[%s]", line[0], line[1]) gc.do_quit("yes") return gc.config diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py index 51793fe0..9d612a7c 100644 --- a/datafaker/interactive/base.py +++ b/datafaker/interactive/base.py @@ -100,7 +100,10 @@ class DbCmd(ABC, cmd.Cmd): INFO_NO_MORE_TABLES = "There are no more tables" ERROR_ALREADY_AT_START = "Error: Already at the start" ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" - ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" + ERROR_NO_SUCH_TABLE_OR_COLUMN = ( + "Error: '{0}' is not the name of a table" + " in this database or a column in this table" + ) ROW_COUNT_MSG = "Total row count: {}" @abstractmethod @@ -185,7 +188,7 @@ def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: :param columns: Dict of column names to the values in the column. """ output = PrettyTable() - row_count = max([len(col) for col in columns.values()]) + row_count = max(len(col) for col in columns.values()) for field_name, data in columns.items(): output.add_column(field_name, list(data) + [None] * (row_count - len(data))) print(output) @@ -207,7 +210,6 @@ def ask_save(self) -> str: @abstractmethod def set_prompt(self) -> None: """Set the prompt according to the current state.""" - ... 
def _set_table_index(self, index: int) -> bool: """ @@ -325,21 +327,25 @@ def do_counts(self, _arg: str) -> None: nonnull_columns = self.get_nonnull_columns(table_name) colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns] with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" + result = ( + connection.execute( + sqlalchemy.text( + f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" + ) ) - ).first() + .mappings() + .first() + ) if result is None: self.print("Could not count rows in table {0}", table_name) return - row_count = result.row_count + row_count = result.get("row_count", 0) self.print(self.ROW_COUNT_MSG, row_count) self.print_table( ["Column", "NULL count"], [ [name, row_count - count] - for name, count in result._mapping.items() + for name, count in result.items() if name != "row_count" ], ) @@ -388,7 +394,7 @@ def do_peek(self, arg: str) -> None: ) try: result = connection.execute(query) - except Exception as exc: + except sqlalchemy.exc.SQLAlchemyError as exc: self.print(f'SQL query "{query}" caused exception {exc}') return self.print_table(list(result.keys()), result.fetchmany(max_peek_rows)) diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py index 544e0cf3..24ccb529 100644 --- a/datafaker/interactive/generators.py +++ b/datafaker/interactive/generators.py @@ -1,4 +1,4 @@ -"""Generator configuration shell.""" +"""Generator configuration shell.""" # pylint: disable=too-many-lines import functools import re from collections.abc import Iterable, Mapping, MutableMapping, Sequence @@ -11,7 +11,13 @@ from datafaker.generators import everything_factory from datafaker.generators.base import Generator, PredefinedGenerator from datafaker.interactive.base import DbCmd, TableEntry, fk_column_name, or_default -from datafaker.utils import logger, primary_private_fks, table_is_private 
+from datafaker.utils import ( + get_columns_assigned, + get_row_generators, + logger, + primary_private_fks, + table_is_private, +) @dataclass @@ -35,6 +41,7 @@ class GeneratorCmdTableEntry(TableEntry): new_generators: list[GeneratorInfo] +# pylint: disable=too-many-public-methods class GeneratorCmd(DbCmd): """Interactive command shell for setting generators.""" @@ -85,50 +92,41 @@ def make_table_entry( return None if table_config.get("num_rows_per_pass", 1) == 0: return None - metadata_table = self.metadata.tables[table_name] - columns = [str(colname) for colname in metadata_table.columns.keys()] + columns = [ + str(colname) for colname in self.metadata.tables[table_name].columns.keys() + ] column_set = frozenset(columns) columns_assigned_so_far: set[str] = set() new_generator_infos: list[GeneratorInfo] = [] - old_generator_infos: list[GeneratorInfo] = [] - for rg in table_config.get("row_generators", []): - gen_name = rg.get("name", None) - if gen_name: - ca = rg.get("columns_assigned", []) - collist: list[str] = ( - [ca] if isinstance(ca, str) else [str(c) for c in ca] + for gen_name, rg in get_row_generators(table_config): + colset: set[str] = set(get_columns_assigned(rg)) + for unknown in colset - column_set: + logger.warning( + "table '%s' has '%s' assigned to column '%s' which is not in this table", + table_name, + gen_name, + unknown, ) - colset: set[str] = set(collist) - for unknown in colset - column_set: - logger.warning( - "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, - gen_name, - unknown, - ) - for mult in columns_assigned_so_far & colset: - logger.warning( - "table '%s' has column '%s' assigned to multiple times", - table_name, - mult, - ) - actual_collist = [c for c in collist if c in columns] - if actual_collist: - gen = PredefinedGenerator(table_name, rg, self.config) - new_generator_infos.append( - GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) - ) - old_generator_infos.append( - 
GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) + for mult in columns_assigned_so_far & colset: + logger.warning( + "table '%s' has column '%s' assigned to multiple times", + table_name, + mult, + ) + actual_collist = [c for c in columns if c in colset] + if actual_collist: + new_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=PredefinedGenerator(table_name, rg, self.config), ) - columns_assigned_so_far |= colset + ) + columns_assigned_so_far |= colset + old_generator_infos = [ + GeneratorInfo(columns=gi.columns.copy(), gen=gi.gen) + for gi in new_generator_infos + ] for colname in columns: if colname not in columns_assigned_so_far: new_generator_infos.append( @@ -139,6 +137,7 @@ def make_table_entry( ) if len(new_generator_infos) == 0: return None + return GeneratorCmdTableEntry( name=table_name, old_generators=old_generator_infos, @@ -853,7 +852,7 @@ def do_unset(self, _arg: str) -> None: self.set_generator(None) self._go_next() - def merge_columns(self, arg: str) -> None: + def merge_columns(self, arg: str) -> bool: """ Add this column(s) to the specified column(s). 
@@ -877,8 +876,7 @@ def merge_columns(self, arg: str) -> None: self.print(self.ERROR_NO_SUCH_COLUMN, uc) return False gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - stated_current_columns = cols_to_merge & current_columns + stated_current_columns = cols_to_merge & frozenset(gen_info.columns) if stated_current_columns: for c in stated_current_columns: self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) @@ -911,7 +909,7 @@ def merge_columns(self, arg: str) -> None: self.set_prompt() return True - def do_merge(self, arg: str): + def do_merge(self, arg: str) -> None: """Add this column(s) to the specified column(s), so one generator covers them all.""" self.merge_columns(arg) @@ -987,7 +985,9 @@ def complete_unmerge( def get_current_columns(self) -> set[str]: """Get the current colums.""" - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() + if table_entry is None: + return set() gen_info = table_entry.new_generators[self.generator_index] return set(gen_info.columns) diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py index 40301b01..d763a14b 100644 --- a/datafaker/interactive/table.py +++ b/datafaker/interactive/table.py @@ -1,5 +1,5 @@ """Table configuration command shell.""" -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from typing import Any, cast diff --git a/datafaker/make.py b/datafaker/make.py index bca7a2b7..6f4cc9bd 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -28,9 +28,11 @@ MaybeAsyncEngine, create_db_engine, download_table, + get_columns_assigned, get_flag, get_property, get_related_table_names, + get_row_generators, get_sync_engine, get_vocabulary_table_names, logger, @@ -176,27 +178,15 @@ def _get_row_generator( ) -> tuple[list[RowGeneratorInfo], list[str]]: """Get the row generators 
information, for the given table.""" row_gen_info: list[RowGeneratorInfo] = [] - config: list[Mapping[str, Any]] = get_property(table_config, "row_generators", []) columns_covered = [] - for gen_conf in config: - name: str = gen_conf["name"] - columns_assigned = gen_conf["columns_assigned"] + for name, gen_conf in get_row_generators(table_config): + columns_assigned = list(get_columns_assigned(gen_conf)) keyword_arguments: Mapping[str, Any] = gen_conf.get("kwargs", {}) positional_arguments: Sequence[str] = gen_conf.get("args", []) - - if isinstance(columns_assigned, str): - columns_assigned = [columns_assigned] - - variable_names: list[str] = columns_assigned - try: - columns_covered += columns_assigned - except TypeError: - # Might be a single string, rather than a list of strings. - columns_covered.append(columns_assigned) - + columns_covered += columns_assigned row_gen_info.append( RowGeneratorInfo( - variable_names=variable_names, + variable_names=columns_assigned, function_call=_get_function_call( name, positional_arguments, keyword_arguments ), diff --git a/datafaker/utils.py b/datafaker/utils.py index 3cc8c282..7ef91bff 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -9,7 +9,17 @@ from collections.abc import Mapping, Sequence from pathlib import Path from types import ModuleType -from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + Final, + Generator, + Generic, + Iterable, + Optional, + TypeVar, + Union, +) import psycopg2 import sqlalchemy @@ -43,6 +53,16 @@ T = TypeVar("T") +class Empty(Generic[T]): + """Generic empty sequences for default arguments.""" + + @classmethod + def iterable(cls) -> Iterable[T]: + """Get an empty iterable.""" + e: list[T] = [] + return (x for x in e) + + def read_config_file(path: str) -> dict: """Read a config file, warning if it is invalid. 
@@ -417,6 +437,44 @@ def get_vocabulary_table_names(config: Mapping) -> set[str]: } +def get_columns_assigned( + row_generator_config: Mapping[str, Any] +) -> Generator[str, None, None]: + """ + Get the columns assigned in a ``row_generators[n]`` stanza. + + :param generator_config: The ``row_generators[n]`` stanza itself. + """ + ca = row_generator_config.get("columns_assigned", None) + if ca is None: + return + if isinstance(ca, str): + yield ca + return + if not hasattr(ca, "__iter__"): + return + for c in ca: + yield str(c) + + +def get_row_generators( + table_config: Mapping[str, Any], +) -> Generator[tuple[str, Mapping[str, Any]], None, None]: + """ + Get the row generators from a table configuration. + + :param table_config: The element from the ``tables:`` stanza of ``config.xml``. + :return: Pair of (name, row generator config). + """ + rgs = table_config.get("row_generators", None) + if isinstance(rgs, str) or not hasattr(rgs, "__iter__"): + return + for rg in rgs: + name = rg.get("name", None) + if name: + yield (name, rg) + + def make_foreign_key_name(table_name: str, col_name: str) -> str: """Make a suitable foreign key name.""" return f"{table_name}_{col_name}_fkey" diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py index 6b4da137..a7eb757d 100644 --- a/tests/test_interactive_generators.py +++ b/tests/test_interactive_generators.py @@ -2,16 +2,11 @@ import copy import re from collections.abc import MutableMapping -from dataclasses import dataclass from typing import Any, Iterable -from unittest import TestCase -from unittest.mock import MagicMock, Mock, patch -from sqlalchemy import Connection, MetaData, insert, select +from sqlalchemy import Connection, MetaData, select -from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.generators.choice import ChoiceGeneratorFactory -from datafaker.interactive import update_config_generators from datafaker.interactive.generators import 
GeneratorCmd from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin @@ -565,6 +560,22 @@ def test_empty_tables_are_not_configured(self) -> None: self.assertNotIn("string", table_names) +class ChoiceMeasurementTableStats: + """Measure the data in the ``choice.sql`` schema.""" + + def __init__(self, metadata: MetaData, connection: Connection): + """Get the data and do the analysis.""" + stmt = select(metadata.tables["number_table"]) + rows = connection.execute(stmt).fetchall() + self.ones: set[int] = set() + self.twos: set[int] = set() + self.threes: set[int] = set() + for row in rows: + self.ones.add(row.one) + self.twos.add(row.two) + self.threes.add(row.three) + + class GeneratorsOutputTests(GeneratesDBTestCase): """Testing choice generation.""" @@ -580,14 +591,16 @@ def setUp(self) -> None: def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + def test_create_with_sampled_choice(self) -> None: """Test that suppression works for choice and zipf_choice.""" - table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -596,9 +609,7 @@ def test_create_with_sampled_choice(self) -> None: self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) gc.do_set(str(proposals["dist_gen.choice [sampled and suppressed]"][0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) 
self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -609,9 +620,7 @@ def test_create_with_sampled_choice(self) -> None: str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]) ) gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -621,34 +630,22 @@ def test_create_with_sampled_choice(self) -> None: gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - ones = set() - twos = set() - threes = set() - for row in rows: - ones.add(row.one) - twos.add(row.two) - threes.add(row.three) # all generation possibilities should be present - self.assertSetEqual(ones, {1, 4}) - self.assertSetEqual(twos, {2, 3}) - self.assertSetEqual(threes, {1, 2, 3, 4, 5}) + with self.sync_engine.connect() as conn: + stats = ChoiceMeasurementTableStats(self.metadata, conn) + self.assertSetEqual(stats.ones, {1, 4}) + self.assertSetEqual(stats.twos, {2, 3}) + self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5}) def test_create_with_choice(self) -> None: """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) gc.do_set(str(proposals["dist_gen.choice"][0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) @@ -668,13 +665,14 @@ def 
test_create_with_weighted_choice(self) -> None: """Smoke test weighted choice.""" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn( - "dist_gen.weighted_choice [sampled and suppressed]", proposals + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), ) prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] self.assert_subset(set(prop[2]), {"1", "4"}) @@ -688,13 +686,14 @@ def test_create_with_weighted_choice(self) -> None: self.assert_subset(col_set, {1, 4}) gc.do_set(str(prop[0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn( - "dist_gen.weighted_choice [sampled and suppressed]", proposals + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), ) prop = proposals["dist_gen.weighted_choice"] self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) @@ -706,11 +705,14 @@ def test_create_with_weighted_choice(self) -> None: self.assert_subset(col_set2, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + }, + set(proposals), + ) 
self.assertNotIn( "dist_gen.weighted_choice [sampled and suppressed]", proposals ) @@ -725,19 +727,12 @@ def test_create_with_weighted_choice(self) -> None: gc.do_quit("") self.generate_data(gc.config, num_passes=200) with self.sync_engine.connect() as conn: - ones = set() - twos = set() - threes = set() - for row in conn.execute( - select(self.metadata.tables["number_table"]) - ).fetchall(): - ones.add(row.one) - twos.add(row.two) - threes.add(row.three) - # all generation possibilities should be present - self.assertSetEqual(ones, {1, 4}) - self.assertSetEqual(twos, {1, 2, 3, 4, 5}) - self.assertSetEqual(threes, {1, 2, 3, 4, 5}) + with self.sync_engine.connect() as conn: + stats = ChoiceMeasurementTableStats(self.metadata, conn) + # all generation possibilities should be present + self.assertSetEqual(stats.ones, {1, 4}) + self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5}) + self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5}) class GeneratorTests(GeneratesDBTestCase): @@ -864,582 +859,3 @@ def test_varchar_ns_are_truncated(self) -> None: stmt = select(self.metadata.tables[table].c[column]) rows = conn.execute(stmt).scalars().fetchall() self.assert_are_truncated_to(rows, 20) - - -@dataclass -class Stat: - """Mean and variance calculator.""" - - n: int = 0 - x: float = 0 - x2: float = 0 - - def add(self, x: float) -> None: - """Add one datum.""" - self.n += 1 - self.x += x - self.x2 += x * x - - def count(self) -> int: - """Get the number of data added.""" - return self.n - - def x_mean(self) -> float: - """Get the mean of the added data.""" - return self.x / self.n - - def x_var(self) -> float: - """Get the variance of the added data.""" - x = self.x - return (self.x2 - x * x / self.n) / (self.n - 1) - - -@dataclass -class Correlation(Stat): - """Mean, variance and covariance.""" - - y: float = 0 - y2: float = 0 - xy: float = 0 - - def add2(self, x: float, y: float) -> None: - """Add a 2D data point.""" - self.n += 1 - self.x += x - self.x2 += x * x - self.y += y - 
self.y2 += y * y - self.xy += x * y - - def y_mean(self) -> float: - """Get the mean of the second parts of the added points.""" - return self.y / self.n - - def y_var(self) -> float: - """Get the variance of the second parts of the added points.""" - y = self.y - return (self.y2 - y * y / self.n) / (self.n - 1) - - def covar(self) -> float: - """Get the covariance of the two parts of the added points.""" - return (self.xy - self.x * self.y / self.n) / (self.n - 1) - - -# pylint disable: too-many-instance-attributes -class EavMeasurementTableStats: - """The statistics for the Measurement table of eav.sql.""" - - def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None: - stmt = select(metadata.tables["measurement"]) - rows = conn.execute(stmt).fetchall() - self.types: set[int] = set() - self.one_count = 0 - self.one_yes_count = 0 - self.two = Correlation() - self.three = Correlation() - self.four = Correlation() - self.fish = Stat() - self.fowl = Stat() - for row in rows: - self.types.add(row.type) - if row.type == 1: - # yes or no - test.assertIsNone(row.first_value) - test.assertIsNone(row.second_value) - test.assertIn(row.third_value, {"yes", "no"}) - self.one_count += 1 - if row.third_value == "yes": - self.one_yes_count += 1 - elif row.type == 2: - # positive correlation around 1.4, 1.8 - test.assertIsNotNone(row.first_value) - test.assertIsNotNone(row.second_value) - test.assertIsNone(row.third_value) - self.two.add2(row.first_value, row.second_value) - elif row.type == 3: - # negative correlation around 11.8, 12.1 - test.assertIsNotNone(row.first_value) - test.assertIsNotNone(row.second_value) - test.assertIsNone(row.third_value) - self.three.add2(row.first_value, row.second_value) - elif row.type == 4: - # positive correlation around 21.4, 23.4 - test.assertIsNotNone(row.first_value) - test.assertIsNotNone(row.second_value) - test.assertIsNone(row.third_value) - self.four.add2(row.first_value, row.second_value) - elif row.type == 5: 
- test.assertIn(row.third_value, {"fish", "fowl"}) - test.assertIsNotNone(row.first_value) - test.assertIsNone(row.second_value) - if row.third_value == "fish": - # mean 8.1 and sd 0.755 - self.fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - self.fowl.add(row.first_value) - - -class NullPartitionedTests(GeneratesDBTestCase): - """Testing null-partitioned grouped multivariate generation.""" - - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - def setUp(self) -> None: - """Set up the test with specific sample and suppress counts.""" - super().setUp() - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 - NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: - """Get the configure-generators object as our command.""" - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_null_partitioned_grouped_multivariate(self) -> None: - """Test EAV for all columns.""" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, - "measurement", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - gc.reset() - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.sync_engine.connect() as conn: - stats = EavMeasurementTableStats(conn, self.metadata, self) - # type 1 - self.assertAlmostEqual( - stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 - ) - # type 2 - self.assertAlmostEqual( - stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 - ) - self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) - self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) - self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) - self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) - # type 3 - self.assertAlmostEqual( - stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) - # type 4 - self.assertAlmostEqual( - stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) - # type 5/fish - self.assertAlmostEqual( - stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) - # type 5/fowl - self.assertAlmostEqual( - stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) - - def 
populate_measurement_type_vocab(self): - """Add a vocab table without messing around with files""" - table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: - conn.execute(insert(table).values({"id": 1, "name": "agreement"})) - conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) - conn.execute(insert(table).values({"id": 3, "name": "velocity"})) - conn.execute(insert(table).values({"id": 4, "name": "position"})) - conn.execute(insert(table).values({"id": 5, "name": "matter"})) - conn.commit() - - def merge_columns( - self, gc: TestGeneratorCmd, table: str, columns: list[str] - ) -> None: - """Merge columns in a table""" - gc.do_next(f"{table}.{columns[0]}") - for col in columns[1:]: - gc.do_merge(col) - gc.reset() - - def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: - """Test EAV for all columns with sampled and suppressed generation.""" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, - "measurement", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", - proposals, - ) - dist_to_choose = ( - "null-partitioned grouped_multivariate_normal [sampled and suppressed]" - ) - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - self.merge_columns( - gc, - "observation", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.sync_engine.connect() as conn: - stats = EavMeasurementTableStats(conn, self.metadata, self) - stmt = select(self.metadata.tables["observation"]) - rows = conn.execute(stmt).fetchall() - firsts = Stat() - for row in rows: - stats.types.add(row.type) - self.assertEqual(row.type, 1) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {"ham", "eggs"}) - firsts.add(row.first_value) - self.assertEqual(firsts.count(), 800) - self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) - self.assert_subset(stats.types, {1, 2, 3, 4, 5}) - self.assertEqual(len(stats.types), 4) - self.assert_subset({1, 5}, stats.types) - # type 1 - self.assertAlmostEqual( - stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 - ) - # type 5/fish - self.assertAlmostEqual( - stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) - # type 5/fowl - self.assertAlmostEqual( - stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, 
delta=1) - - def test_create_with_null_partitioned_grouped_sampled_only(self): - """Test EAV for all columns with sampled generation but no suppression.""" - table_name = "measurement" - table2_name = "observation" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, table_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", - proposals, - ) - self.assertIn( - "null-partitioned grouped_multivariate_normal [sampled and suppressed]", - proposals, - ) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled]", proposals - ) - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - self.merge_columns( - gc, table2_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5}) - stmt = select(self.metadata.tables[table2_name]) - rows = conn.execute(stmt).fetchall() - self.assertEqual( - {row.third_value for row in rows}, {"ham", "eggs", "cheese"} - ) - - def test_create_with_null_partitioned_grouped_sampled_tiny(self): - """ - Test EAV for all columns with sampled generation that only gets a tiny sample. - """ - # five will ensure that at least one group will have two elements in it, - # but all three cannot. 
- NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 - table_name = "observation" - generate_count = 100 - with self._get_cmd({}) as gc: - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" - self.merge_columns( - gc, table_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - # breakpoint() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - # we should only have one or two of "ham", "eggs" and "cheese" represented - foods = {row.third_value for row in rows} - self.assert_subset(foods, {"ham", "eggs", "cheese"}) - self.assertLess(len(foods), 3) - - -class NonInteractiveTests(RequiresDBTestCase): - """ - Test the --spec SPEC_FILE option of configure-generators - """ - - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - ["observation", "type", "dist_gen.weighted_choice [sampled]"], - [ - "observation", - "first_value", - "dist_gen.weighted_choice", - "dist_gen.constant", - ], - [ - "observation", - "second_value", - "dist_gen.weighted_choice", - "dist_gen.weighted_choice [sampled]", - "dist_gen.constant", - ], - ["observation", "third_value", "dist_gen.weighted_choice"], - ] - ), - ) - def test_non_interactive_configure_generators( - self, _mock_csv_reader: MagicMock, _mock_path: MagicMock - ) -> None: - """ - test that we can set generators from a CSV file - """ - config: MutableMapping[str, Any] = {} - spec_csv = Mock(return_value="mock spec.csv file") - 
update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg["name"] - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") - self.assertEqual( - row_gens["observation['first_value']"], "dist_gen.weighted_choice" - ) - self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") - self.assertEqual( - row_gens["observation['third_value']"], "dist_gen.weighted_choice" - ) - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ] - ), - ) - def test_non_interactive_configure_null_partitioned( - self, mock_csv_reader: MagicMock, mock_path: MagicMock - ): - """ - test that we can set multi-column generators from a CSV file - """ - config = {} - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - '"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) - - @patch("datafaker.interactive.Path") - @patch( - 
"datafaker.interactive.csv.reader", - return_value=iter( - [ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ] - ), - ) - def test_non_interactive_configure_null_partitioned_where_existing_merges( - self, _mock_csv_reader: MagicMock, _mock_path: MagicMock - ) -> None: - """ - test that we can set multi-column generators from a CSV file, - but where there are already multi-column generators configured - that will have to be unmerged. - """ - config = { - "tables": { - "observation": { - "row_generators": [ - { - "name": "arbitrary_gen", - "columns_assigned": [ - "type", - "second_value", - "first_value", - ], - } - ], - }, - }, - } - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - '"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py new file mode 100644 index 00000000..3d5c3126 --- /dev/null +++ b/tests/test_interactive_generators_partitioned.py @@ -0,0 +1,419 @@ +"""Tests for null-partitioned generators.""" +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Any +from unittest import TestCase + +from 
sqlalchemy import Connection, MetaData, insert, select + +from datafaker.generators import NullPartitionedNormalGeneratorFactory +from tests.test_interactive_generators import TestGeneratorCmd +from tests.utils import GeneratesDBTestCase + + +@dataclass +class Stat: + """Mean and variance calculator.""" + + n: int = 0 + x: float = 0 + x2: float = 0 + + def add(self, x: float) -> None: + """Add one datum.""" + self.n += 1 + self.x += x + self.x2 += x * x + + def count(self) -> int: + """Get the number of data added.""" + return self.n + + def x_mean(self) -> float: + """Get the mean of the added data.""" + return self.x / self.n + + def x_var(self) -> float: + """Get the variance of the added data.""" + x = self.x + return (self.x2 - x * x / self.n) / (self.n - 1) + + +@dataclass +class Correlation(Stat): + """Mean, variance and covariance.""" + + y: float = 0 + y2: float = 0 + xy: float = 0 + + def add2(self, x: float, y: float) -> None: + """Add a 2D data point.""" + self.n += 1 + self.x += x + self.x2 += x * x + self.y += y + self.y2 += y * y + self.xy += x * y + + def y_mean(self) -> float: + """Get the mean of the second parts of the added points.""" + return self.y / self.n + + def y_var(self) -> float: + """Get the variance of the second parts of the added points.""" + y = self.y + return (self.y2 - y * y / self.n) / (self.n - 1) + + def covar(self) -> float: + """Get the covariance of the two parts of the added points.""" + return (self.xy - self.x * self.y / self.n) / (self.n - 1) + + +# pylint: disable=too-many-instance-attributes +class EavMeasurementTableStats: + """The statistics for the Measurement table of eav.sql.""" + + def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None: + stmt = select(metadata.tables["measurement"]) + rows = conn.execute(stmt).fetchall() + self.types: set[int] = set() + self.one_count = 0 + self.one_yes_count = 0 + self.two = Correlation() + self.three = Correlation() + self.four = Correlation() + 
self.fish = Stat() + self.fowl = Stat() + for row in rows: + self.types.add(row.type) + if row.type == 1: + # yes or no + test.assertIsNone(row.first_value) + test.assertIsNone(row.second_value) + test.assertIn(row.third_value, {"yes", "no"}) + self.one_count += 1 + if row.third_value == "yes": + self.one_yes_count += 1 + elif row.type == 2: + # positive correlation around 1.4, 1.8 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.two.add2(row.first_value, row.second_value) + elif row.type == 3: + # negative correlation around 11.8, 12.1 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.three.add2(row.first_value, row.second_value) + elif row.type == 4: + # positive correlation around 21.4, 23.4 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.four.add2(row.first_value, row.second_value) + elif row.type == 5: + test.assertIn(row.third_value, {"fish", "fowl"}) + test.assertIsNotNone(row.first_value) + test.assertIsNone(row.second_value) + if row.third_value == "fish": + # mean 8.1 and sd 0.755 + self.fish.add(row.first_value) + else: + # mean 11.2 and sd 1.114 + self.fowl.add(row.first_value) + + +class NullPartitionedTests(GeneratesDBTestCase): + """Testing null-partitioned grouped multivariate generation.""" + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + def setUp(self) -> None: + """Set up the test with specific sample and suppress counts.""" + super().setUp() + NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 + NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the configure-generators object as our command.""" + return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + + def 
_propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + + def test_create_with_null_partitioned_grouped_multivariate(self) -> None: + """Test EAV for all columns.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + dist_to_choose = "null-partitioned grouped_multivariate_normal" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. {dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + gc.reset() + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 2 + self.assertAlmostEqual( + stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) + self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) + self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) + self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) + # type 3 + self.assertAlmostEqual( + stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + 
self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) + # type 4 + self.assertAlmostEqual( + stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + + def populate_measurement_type_vocab(self) -> None: + """Add a vocab table without messing around with files""" + table = self.metadata.tables["measurement_type"] + with self.sync_engine.connect() as conn: + conn.execute(insert(table).values({"id": 1, "name": "agreement"})) + conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) + conn.execute(insert(table).values({"id": 3, "name": "velocity"})) + conn.execute(insert(table).values({"id": 4, "name": "position"})) + conn.execute(insert(table).values({"id": 5, "name": "matter"})) + conn.commit() + + def merge_columns( + self, gc: TestGeneratorCmd, table: str, columns: list[str] + ) -> None: + """Merge columns in a table""" + gc.do_next(f"{table}.{columns[0]}") + for col in columns[1:]: + gc.do_merge(col) + gc.reset() + + def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: + """Test EAV for all columns with sampled and suppressed generation.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assert_subset( + { + "null-partitioned grouped_multivariate_lognormal", + "null-partitioned 
grouped_multivariate_normal", + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + }, + set(proposals), + ) + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. {dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, + "observation", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + stmt = select(self.metadata.tables["observation"]) + rows = conn.execute(stmt).fetchall() + firsts = Stat() + for row in rows: + stats.types.add(row.type) + self.assertEqual(row.type, 1) + self.assertIsNotNone(row.first_value) + self.assertIsNone(row.second_value) + self.assertIn(row.third_value, {"ham", "eggs"}) + firsts.add(row.first_value) + self.assertEqual(firsts.count(), 800) + self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) + self.assert_subset(stats.types, {1, 2, 3, 4, 5}) + self.assertEqual(len(stats.types), 4) + self.assert_subset({1, 5}, stats.types) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) 
+ self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + + def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: + """Test EAV for all columns with sampled generation but no suppression.""" + table_name = "measurement" + table2_name = "observation" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + self.assertIn("null-partitioned grouped_multivariate_normal", proposals) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled]", proposals + ) + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
{dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, table2_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5}) + stmt = select(self.metadata.tables[table2_name]) + rows = conn.execute(stmt).fetchall() + self.assertEqual( + {row.third_value for row in rows}, {"ham", "eggs", "cheese"} + ) + + def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None: + """ + Test EAV for all columns with sampled generation that only gets a tiny sample. + """ + # five will ensure that at least one group will have two elements in it, + # but all three cannot. 
+ NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 + table_name = "observation" + generate_count = 100 + with self._get_cmd({}) as gc: + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + # we should only have one or two of "ham", "eggs" and "cheese" represented + foods = {row.third_value for row in rows} + self.assert_subset(foods, {"ham", "eggs", "cheese"}) + self.assertLess(len(foods), 3) diff --git a/tests/test_noninteractive_generators.py b/tests/test_noninteractive_generators.py new file mode 100644 index 00000000..93431147 --- /dev/null +++ b/tests/test_noninteractive_generators.py @@ -0,0 +1,179 @@ +""" Tests for the configure-generators command with the --spec option. 
""" + +from collections.abc import Mapping, MutableMapping +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +from datafaker.interactive import update_config_generators +from tests.utils import RequiresDBTestCase + + +class NonInteractiveTests(RequiresDBTestCase): + """ + Test the --spec SPEC_FILE option of configure-generators + """ + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + ["observation", "type", "dist_gen.weighted_choice [sampled]"], + [ + "observation", + "first_value", + "dist_gen.weighted_choice", + "dist_gen.constant", + ], + [ + "observation", + "second_value", + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.constant", + ], + ["observation", "third_value", "dist_gen.weighted_choice"], + ] + ), + ) + def test_non_interactive_configure_generators( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg["name"] + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") + self.assertEqual( + row_gens["observation['first_value']"], "dist_gen.weighted_choice" + ) + self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") + self.assertEqual( + row_gens["observation['third_value']"], "dist_gen.weighted_choice" + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value 
third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set multi-column generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["name"], + "dist_gen.alternatives", + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["name"], + '"with_constants_at"', + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], + '"grouped_multivariate_lognormal"', + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned_where_existing_merges( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set multi-column generators from a CSV file, + but where there are already multi-column generators configured + that will have to be unmerged. 
+ """ + config = { + "tables": { + "observation": { + "row_generators": [ + { + "name": "arbitrary_gen", + "columns_assigned": [ + "type", + "second_value", + "first_value", + ], + } + ], + }, + }, + } + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens: Mapping[str, Any] = { + f"{table}{sorted(rg['columns_assigned'])}": rg + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["name"], + "dist_gen.alternatives", + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["name"], + '"with_constants_at"', + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], + '"grouped_multivariate_lognormal"', + ) From 05ea3780c67eeefe6a744b1c8127b2087934b3d8 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:18:52 +0100 Subject: [PATCH 22/44] Add running tests to pre-commit.yml --- .github/workflows/pre-commit.yml | 4 ++++ CONTRIBUTING.md | 11 ----------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b07de4d1..38951218 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -59,3 +59,7 @@ jobs: shell: bash run: | pre-commit run --all-files + - name: Run tests + shell: bash + run: | + python -m unittest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8d7b8799..234f8fb8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,17 +47,6 @@ Executing unit tests is straightforward: python -m unittest discover --verbose tests/ ``` -for tests that are currently maintained. 
- -## Running functional tests - -These tests do not currently work, and will be replaced by unit tests. - -Functional tests require PostgreSQL to be installed. - -..warning:: - Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file. - ## Building documentation locally ```bash From 91036ceff87bfe5b82f904dfdd0563e27ee8d70e Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:25:37 +0100 Subject: [PATCH 23/44] Github actions starting PostgreSQL --- .github/workflows/pre-commit.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 38951218..57733147 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -59,6 +59,10 @@ jobs: shell: bash run: | pre-commit run --all-files + - name: Start PostgreSQL + shell: bash + run: | + sudo systemctl start postgresql.service - name: Run tests shell: bash run: | From 69e0933bc059f586a6650b4a96de778748f9948f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:08:53 +0100 Subject: [PATCH 24/44] Fixed test_unique_constraint_fails --- .pre-commit-config.yaml | 5 +---- .pylintrc | 3 +-- tests/test_functional.py | 9 ++++++++- tests/test_utils.py | 12 +++++++++--- tests/workspace/.gitignore | 1 - tests/workspace/README.md | 3 --- 6 files changed, 19 insertions(+), 14 deletions(-) delete mode 100644 tests/workspace/.gitignore delete mode 100644 tests/workspace/README.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04464f9d..a99928fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,8 +40,7 @@ repos: language: system types: ['python'] exclude: (?x)( - tests/examples| - tests/workspace + tests/examples ) - id: isort name: isort @@ -50,7 +49,6 @@ repos: types: ['python'] 
exclude: (?x)( tests/examples| - tests/workspace| examples ) - id: pylint @@ -77,7 +75,6 @@ repos: language: system exclude: (?x)( tests/examples| - tests/workspace| examples ) types: ['python'] diff --git a/.pylintrc b/.pylintrc index 22a92bd7..cb276e25 100644 --- a/.pylintrc +++ b/.pylintrc @@ -24,8 +24,7 @@ ignore=CVS # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths. -ignore-paths=tests/examples, - tests/workspace +ignore-paths=tests/examples # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/tests/test_functional.py b/tests/test_functional.py index ac7e51a7..a7374fbb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,6 +2,7 @@ import os import shutil from pathlib import Path +import tempfile from typing import Any, Mapping from sqlalchemy import create_engine, inspect @@ -20,7 +21,6 @@ class DBFunctionalTestCase(RequiresDBTestCase): database_name = "src" schema_name = "public" - test_dir = Path("tests/workspace") examples_dir = Path("tests/examples") orm_file_path = Path("orm.yaml") @@ -67,6 +67,7 @@ def setUp(self) -> None: ) # Copy some of the example files over to the workspace. 
+ self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) for file in self.generator_file_paths + (self.config_file_path,): src = self.examples_dir / file dst = self.test_dir / file @@ -501,6 +502,12 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", "--force", ) + self.invoke( + "make-vocab", + f"--orm-file={self.alt_orm_file_path}", + f"--config-file={self.config_file_path}", + "--force", + ) self.invoke( "make-stats", f"--stats-file={self.stats_file_path}", diff --git a/tests/test_utils.py b/tests/test_utils.py index 54c167a2..808186fc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,9 @@ """Tests for the utils module.""" +import importlib import os import sys from pathlib import Path +import tempfile from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert @@ -14,6 +16,7 @@ read_config_file, ) from tests.utils import DatafakerTestCase, RequiresDBTestCase +from . import examples # pylint: disable=invalid-name Base = declarative_base() @@ -60,7 +63,6 @@ class TestDownload(RequiresDBTestCase): dump_file_path = "providers.dump" mytable_file_path = Path("mytable.yaml") - test_dir = Path("tests/workspace") start_dir = os.getcwd() def setUp(self) -> None: @@ -69,6 +71,7 @@ def setUp(self) -> None: metadata.create_all(self.engine) + self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) os.chdir(self.test_dir) self.mytable_file_path.unlink(missing_ok=True) @@ -90,8 +93,11 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. 
- with Path("../examples/expected.yaml").open(encoding="utf-8") as yamlfile: - expected = yamlfile.read().strip() + with importlib.resources.as_file( + importlib.resources.files(examples) / "expected.yaml" + ) as yamlpath: + with yamlpath.open() as yamlfile: + expected = yamlfile.read().strip() with self.mytable_file_path.open(encoding="utf-8") as yamlfile: actual = yamlfile.read().strip() diff --git a/tests/workspace/.gitignore b/tests/workspace/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/tests/workspace/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/tests/workspace/README.md b/tests/workspace/README.md deleted file mode 100644 index 8165a69a..00000000 --- a/tests/workspace/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Test Workspace - -A workspace for the functional tests to run in. From 820700e637dbcca58ae4553ae1cd7fe100524ffc Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:18:41 +0100 Subject: [PATCH 25/44] Cleaned up --- tests/test_functional.py | 2 +- tests/test_utils.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index a7374fbb..dc7ef48a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,8 +1,8 @@ """Tests for the CLI.""" import os import shutil -from pathlib import Path import tempfile +from pathlib import Path from typing import Any, Mapping from sqlalchemy import create_engine, inspect diff --git a/tests/test_utils.py b/tests/test_utils.py index 808186fc..aab9ee04 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,9 @@ """Tests for the utils module.""" -import importlib import os import sys -from pathlib import Path import tempfile +from importlib import resources +from pathlib import Path from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert @@ -16,6 +16,7 @@ read_config_file, ) from tests.utils import DatafakerTestCase, RequiresDBTestCase 
+ from . import examples # pylint: disable=invalid-name @@ -93,10 +94,8 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. - with importlib.resources.as_file( - importlib.resources.files(examples) / "expected.yaml" - ) as yamlpath: - with yamlpath.open() as yamlfile: + with resources.as_file(resources.files(examples) / "expected.yaml") as yamlpath: + with yamlpath.open(encoding="utf-8") as yamlfile: expected = yamlfile.read().strip() with self.mytable_file_path.open(encoding="utf-8") as yamlfile: From 433bb19e073cc4f4863b219de9d0839c6212138b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:46:06 +0100 Subject: [PATCH 26/44] Fixed tests --- tests/test_interactive_generators_partitioned.py | 4 ++-- tests/test_utils.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index 3d5c3126..f544be81 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -196,8 +196,8 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) - self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.08) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.08) # type 3 self.assertAlmostEqual( stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 diff --git a/tests/test_utils.py b/tests/test_utils.py index aab9ee04..0d2427e7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ """Tests for the utils module.""" +import importlib.util import os import sys import 
tempfile @@ -17,8 +18,6 @@ ) from tests.utils import DatafakerTestCase, RequiresDBTestCase -from . import examples - # pylint: disable=invalid-name Base = declarative_base() # pylint: enable=invalid-name @@ -94,7 +93,10 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. - with resources.as_file(resources.files(examples) / "expected.yaml") as yamlpath: + tests_module = sys.modules["tests"] + with resources.as_file( + resources.files(tests_module) / "examples" / "expected.yaml" + ) as yamlpath: with yamlpath.open(encoding="utf-8") as yamlfile: expected = yamlfile.read().strip() From d1b07dc7d35cb3012ce36bbb387d1bf645351aee Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:52:29 +0100 Subject: [PATCH 27/44] cleaned --- tests/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d2427e7..ac82d124 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,4 @@ """Tests for the utils module.""" -import importlib.util import os import sys import tempfile From 79990a1c08252b06f19fbdbffa1bf4df5ab71599 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:06:37 +0100 Subject: [PATCH 28/44] Move real test runner to tests.yml, overwriting bad test runner --- .github/workflows/pre-commit.yml | 10 +------- .github/workflows/tests.yml | 44 ++++---------------------------- 2 files changed, 6 insertions(+), 48 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 57733147..ec27ef86 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -11,7 +11,7 @@ env: PRE_COMMIT_HOME: ~/.caches/pre-commit PYTHON_VERSION: "3.10" jobs: - the_job: + clean-check: runs-on: ubuntu-latest steps: - name: Checkout Code @@ -59,11 +59,3 @@ jobs: shell: bash run: | pre-commit run --all-files - - name: Start PostgreSQL - shell: bash - run: | - sudo systemctl start 
postgresql.service - - name: Run tests - shell: bash - run: | - python -m unittest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75f45f03..9d79e811 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,48 +10,14 @@ env: # This should be the default but we'll be explicit PYTHON_VERSION: "3.10" jobs: - the_job: + unit-tests: runs-on: ubuntu-latest - services: - postgres: - image: postgres:15 - env: - POSTGRES_PASSWORD: password - ports: - - 5432:5432 steps: - - name: Checkout Code - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - name: Bootstrap poetry + - name: Start PostgreSQL shell: bash run: | - python -m ensurepip - python -m pip install --upgrade pip - python -m pip install poetry - - name: Configure poetry + sudo systemctl start postgresql.service + - name: Run tests shell: bash run: | - python -m poetry config virtualenvs.in-project true - # - name: Cache Poetry dependencies - # uses: actions/cache@v3 - # id: poetry-cache - # with: - # path: .venv - # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} - - name: Install dependencies - shell: bash - if: steps.poetry-cache.outputs.cache-hit != 'true' - run: | - python -m poetry install --all-extras - - name: Create src database - shell: bash - run: | - PGPASSWORD=password psql --host=localhost --username=postgres --set="ON_ERROR_STOP=1" --file=tests/examples/src.dump - - name: Run Unit Tests - shell: bash - run: | - REQUIRES_DB=1 poetry run python -m unittest discover --verbose tests + python -m unittest From 9d0ec4775d60154de9ecd036ea3ae314d75068f5 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:10:57 +0100 Subject: [PATCH 29/44] Added poetry initialisation to test runner --- .github/workflows/tests.yml | 20 ++++++++++++++++++++ tests/test_functional.py | 3 --- 2 files changed, 20 insertions(+), 3 deletions(-) 
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d79e811..8a9b3108 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,26 @@ jobs: shell: bash run: | sudo systemctl start postgresql.service + - name: Bootstrap poetry + shell: bash + run: | + python -m ensurepip + python -m pip install --upgrade pip + python -m pip install poetry + - name: Configure poetry + shell: bash + run: | + python -m poetry config virtualenvs.in-project true + # - name: Cache Poetry dependencies uses: actions/cache@v3 + # id: poetry-cache + # with: + # path: .venv + # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} + - name: Install dependencies + shell: bash + if: steps.poetry-cache.outputs.cache-hit != 'true' + run: | + python -m poetry install --all-extras - name: Run tests shell: bash run: | diff --git a/tests/test_functional.py b/tests/test_functional.py index dc7ef48a..bfb2f096 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -29,9 +29,6 @@ class DBFunctionalTestCase(RequiresDBTestCase): alt_orm_file_path = Path("my_orm.yaml") alt_datafaker_file_path = Path("my_df.py") - vocabulary_file_paths = tuple( - map(Path, ("concept.yaml", "concept_type.yaml", "mitigation_type.yaml")), - ) generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), ) From 83438bb56a276a91578adf8e49f907086e8e6684 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:47:43 +0100 Subject: [PATCH 30/44] More test fixes --- .github/workflows/tests.yml | 13 +++---------- tests/test_interactive_generators_partitioned.py | 6 +++--- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8a9b3108..964d9436 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,21 +17,14 @@ jobs: shell: bash run: | sudo systemctl start postgresql.service - - name: Bootstrap poetry + 
- name: Install poetry shell: bash run: | - python -m ensurepip - python -m pip install --upgrade pip - python -m pip install poetry + sudo apt install python3-poetry - name: Configure poetry shell: bash run: | python -m poetry config virtualenvs.in-project true - # - name: Cache Poetry dependencies uses: actions/cache@v3 - # id: poetry-cache - # with: - # path: .venv - # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} - name: Install dependencies shell: bash if: steps.poetry-cache.outputs.cache-hit != 'true' @@ -40,4 +33,4 @@ jobs: - name: Run tests shell: bash run: | - python -m unittest + poetry run python -m unittest diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index f544be81..3b21b535 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -207,7 +207,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: self.assertAlmostEqual( stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1.3) # type 5/fish self.assertAlmostEqual( stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 @@ -219,7 +219,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) def populate_measurement_type_vocab(self) -> None: """Add a vocab table without messing around with files""" @@ -330,7 +330,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) 
self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: """Test EAV for all columns with sampled generation but no suppression.""" From 2306b12dd0490a0231b35ff9f4a9b318ad3e64aa Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:55:54 +0100 Subject: [PATCH 31/44] Another attempt to get tests.yml working --- .github/workflows/tests.yml | 2 ++ .pre-commit-config.yaml | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 964d9436..389e0734 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,6 +13,8 @@ jobs: unit-tests: runs-on: ubuntu-latest steps: + - name: Checkout Code + uses: actions/checkout@v3 - name: Start PostgreSQL shell: bash run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a99928fa..085ce544 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude: docs/(source|build/html)/_static/ @@ -13,12 +13,12 @@ repos: - id: check-added-large-files - repo: https://github.com/markdownlint/markdownlint # Note the "v" - rev: v0.11.0 + rev: v0.12.0 hooks: - id: markdownlint args: [--style=mdl_style.rb] - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.4 + rev: v0.11.0.1 hooks: - id: shellcheck - repo: local From 3c1c9aaa8f5ade110fb13ed027017a8710924662 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 22 Oct 2025 11:18:04 +0100 Subject: [PATCH 32/44] Initial attempt at a static version of df.py --- datafaker/populate.py | 189 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 
insertions(+) create mode 100644 datafaker/populate.py diff --git a/datafaker/populate.py b/datafaker/populate.py new file mode 100644 index 00000000..79114027 --- /dev/null +++ b/datafaker/populate.py @@ -0,0 +1,189 @@ +"""This file was auto-generated by datafaker but can be edited manually.""" +from collections.abc import Iterable, Mapping, MutableMapping, Sequence +from mimesis import Generic, Numeric, Person +from mimesis.locales import Locale +import sqlalchemy +import sys +from typing import Any, Callable +import yaml + +from datafaker.base import FileUploader, TableGenerator, DistributionGenerator, ColumnPresence +from datafaker.main import load_metadata +from datafaker.make import TableGeneratorInfo, StoryGeneratorInfo #TODO: move these in here! + +from datafaker.providers import ( + BytesProvider, + ColumnValueProvider, + NullProvider, + SQLGroupByProvider, + TimedeltaProvider, + TimespanProvider, + WeightedBooleanProvider, +) +from datafaker.utils import logging, get_vocabulary_table_names + +generic = Generic(locale=Locale.EN_GB) +numeric = Numeric() +person = Person() +dist_gen = DistributionGenerator() +column_presence = ColumnPresence() + +generic.add_provider(BytesProvider) +generic.add_provider(ColumnValueProvider) +generic.add_provider(NullProvider) +generic.add_provider(SQLGroupByProvider) +generic.add_provider(TimedeltaProvider) +generic.add_provider(TimespanProvider) +generic.add_provider(WeightedBooleanProvider) + +#metadata = load_metadata("{{ orm_file_name }}", "{{ config_file_name }}") + +#import {{ row_generator_module_name }} +#import {{ story_generator_module_name }} + +def _eval_structure(config: Any, context: Mapping) -> Any: + """ + Turn a structure from ``config.yaml`` into a Python object. + + :param config: a structure (list, dict, number or expression in a string). + :return: Object matching the structure of ``config`` with strings eval'ed. 
+ """ + if isinstance(config, str): + return eval(config, locals=context) + if isinstance(config, Mapping): + return { + k: _eval_structure(v, context) + for k, v in config.items() + } + if isinstance(config, Sequence): + return [_eval_structure(v, context) for v in config] + return config + + +def _get_object(class_name: Any, context: Mapping) -> Any: + """ + Get an object out of the context. + + :param class_name: The name of the class, qualified if necessary. + Like "module.MyClass.Nested" + :param context: Mapping of strings to objects with those names. + :return: A value from ``context`` if there are no qualifying names, + otherwise the attribute of the base object. + """ + if not isinstance(class_name, str): + return None + if not isinstance(kwargs, Mapping): + kwargs = {} + parts = class_name.split(".") + if parts[0] not in context: + logging.error('No such object "%"', parts[0]) + return None + value = context[parts[0]] + so_far = parts[0] + for part in parts[1:]: + so_far += "." + part + if not hasattr(value, part): + logging.error('No such attribute "%"', so_far) + return None + value = getattr(value, part) + return value + + +def _call_from_context(callable_name: Any, kwargs: Any, context: Mapping) -> Any: + """ + Call a callable from the classes (or functions) in the context. + + :param class_name: Possibly qualified name of class to construct. + :param context: Mapping of base classes and modules + :return: Constructed object, or None if this did not work. 
+ """ + cls = _get_object(callable_name, context) + if not isinstance(cls, Callable): + return None + kws = _eval_structure(kwargs, context) + if kws is None: + return None + return cls(**kws) + + +def _get_src_stats(src_stats_filename: str) -> Any: + """ + Get the SRC_STATS object + """ + with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f: + return yaml.unsafe_load(f) + + +class TableGenerator: + + def __init__(self, rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, table_data: TableGeneratorInfo, max_unique_constraint_tries: int | None): + self.num_rows_per_pass = rows_per_pass + self.table_data = table_data + self.max_unique_constraint_tries = max_unique_constraint_tries + self.existing_constraint_hashes: MutableMapping[str, set[int]] = {} + self.context: Mapping = {} + for constraint in table_data.unique_constraints: + expr = sqlalchemy.select(constraint.columns) + query_result = dst_db_conn.execute(expr).fetchall() + self.existing_constraint_hashes[constraint.name] = set([ + hash(tuple(result)) + for result in query_result + ]) + + def set_context(self, context: Mapping): + self.context = context + + def __call__(self, dst_db_conn): + result = {} + columns_to_generate = set(self.table_data.nonnull_columns) + # Which missingness patterns do we want? 
+ for choice in self.table_data.column_choices: + cols = _call_from_context(choice.function_name, choice.argument_values, self.context) + columns_to_generate.update(cols) + + max_tries = self.max_unique_constraint_tries + while columns_to_generate: + if max_tries == 0: + raise RuntimeError(f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts.") + if max_tries is not None: + max_tries -= 1 + for row_gen in self.table_data.row_gens: + if set(row_gen.variable_names) & columns_to_generate: + values = _call_from_context(row_gen.function_call.function_name, row_gen.function_call.argument_values, self.context) + for index, variable_name in enumerate(row_gen.variable_names): + result[variable_name] = values[index] + columns_to_generate = set() + for constraint in self.table_data.unique_constraints: + cf_hash = hash(tuple( + result[col.name] for col in constraint.columns + )) + if cf_hash in self.existing_constraint_hashes[constraint.name]: + columns_to_generate.update(c.name for c in constraint.columns) + for constraint in self.table_data.unique_constraints: + cf_hash = hash(tuple( + result[col.name] for col in constraint.columns + )) + self.existing_constraint_hashes.add(cf_hash) + return result + +def get_table_generator_dict(self, rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, tables_data: Iterable[TableGeneratorInfo], max_unique_constraint_tries: int | None): + return { + "{{ table_data.table_name }}": TableGenerator(rows_per_pass, dst_db_conn, table_data, max_unique_constraint_tries) + for table_data in tables_data +} + + +def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[str, FileUploader]: { + name: FileUploader[metadata.tables[name]] + for name in get_vocabulary_table_names(config) +} + +def get_story_generator_list(story_generator_infos: Iterable[StoryGeneratorInfo], context: Mapping) -> list[Mapping]: + return [ + { + "function": 
_call_from_context(gen_data.function_call.function_name, gen_data.function_call.argument_values, context), + "num_stories_per_pass": {{ gen_data.num_stories_per_pass }}, + "name": "{{ gen_data.function_call.function_name }}", + } + for gen_data in story_generator_infos + ] From 6cde9fd4733e51c2457582e728179544b197e44a Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 17 Mar 2026 18:30:58 +0000 Subject: [PATCH 33/44] A few updates. --- datafaker/populate.py | 76 ++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/datafaker/populate.py b/datafaker/populate.py index 79114027..e94af994 100644 --- a/datafaker/populate.py +++ b/datafaker/populate.py @@ -1,14 +1,11 @@ -"""This file was auto-generated by datafaker but can be edited manually.""" from collections.abc import Iterable, Mapping, MutableMapping, Sequence -from mimesis import Generic, Numeric, Person +from mimesis import Generic from mimesis.locales import Locale import sqlalchemy -import sys from typing import Any, Callable import yaml -from datafaker.base import FileUploader, TableGenerator, DistributionGenerator, ColumnPresence -from datafaker.main import load_metadata +from datafaker.base import FileUploader, TableGenerator from datafaker.make import TableGeneratorInfo, StoryGeneratorInfo #TODO: move these in here! 
from datafaker.providers import ( @@ -23,10 +20,6 @@ from datafaker.utils import logging, get_vocabulary_table_names generic = Generic(locale=Locale.EN_GB) -numeric = Numeric() -person = Person() -dist_gen = DistributionGenerator() -column_presence = ColumnPresence() generic.add_provider(BytesProvider) generic.add_provider(ColumnValueProvider) @@ -36,10 +29,6 @@ generic.add_provider(TimespanProvider) generic.add_provider(WeightedBooleanProvider) -#metadata = load_metadata("{{ orm_file_name }}", "{{ config_file_name }}") - -#import {{ row_generator_module_name }} -#import {{ story_generator_module_name }} def _eval_structure(config: Any, context: Mapping) -> Any: """ @@ -110,16 +99,34 @@ def _get_src_stats(src_stats_filename: str) -> Any: """ Get the SRC_STATS object """ - with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f: - return yaml.unsafe_load(f) + with open(src_stats_filename, "r", encoding="utf-8") as f: + return yaml.load(f, yaml.SafeLoader) class TableGenerator: - def __init__(self, rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, table_data: TableGeneratorInfo, max_unique_constraint_tries: int | None): + def __init__( + self, + rows_per_pass: int, + dst_db_conn: sqlalchemy.Connection, + table_data: TableGeneratorInfo, + max_unique_constraint_tries: int | None, + ) -> None: + """ + Initialize a table generator. + + :param rows_per_pass: How many rows to add for each call to ``__call__``. + :param dst_db_conn: Connection to the destination database. + :param table_data: Configuration for this generator. + :param max_unique_constraint_tries: How many times to redo generation in + an attempt to satisfy uniqueness constraints. None means never stop, but + this could cause an infinite loop if there are no solutions, or very long + execution if there are few solutions with many constraints. 
+ """ self.num_rows_per_pass = rows_per_pass self.table_data = table_data self.max_unique_constraint_tries = max_unique_constraint_tries + self.db_conn = dst_db_conn self.existing_constraint_hashes: MutableMapping[str, set[int]] = {} self.context: Mapping = {} for constraint in table_data.unique_constraints: @@ -130,11 +137,13 @@ def __init__(self, rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, table for result in query_result ]) - def set_context(self, context: Mapping): + def set_context(self, context: Mapping) -> None: + """Sets all the Python symbols that must be known to the configuration.""" self.context = context - def __call__(self, dst_db_conn): - result = {} + def __call__(self): + """Generate some rows of the relevant table in the database.""" + result: dict[str, Any] = {} columns_to_generate = set(self.table_data.nonnull_columns) # Which missingness patterns do we want? for choice in self.table_data.column_choices: @@ -166,24 +175,33 @@ def __call__(self, dst_db_conn): self.existing_constraint_hashes.add(cf_hash) return result -def get_table_generator_dict(self, rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, tables_data: Iterable[TableGeneratorInfo], max_unique_constraint_tries: int | None): +def get_table_generator_dict( + rows_per_pass: int, + dst_db_conn: sqlalchemy.Connection, + tables_data: Iterable[TableGeneratorInfo], + max_unique_constraint_tries: int | None, +): + """Get a dict of table names to row generators that generate rows for that table.""" return { - "{{ table_data.table_name }}": TableGenerator(rows_per_pass, dst_db_conn, table_data, max_unique_constraint_tries) - for table_data in tables_data -} + table_data.table_name: TableGenerator(rows_per_pass, dst_db_conn, table_data, max_unique_constraint_tries) + for table_data in tables_data + } -def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[str, FileUploader]: { - name: FileUploader[metadata.tables[name]] - for name in 
get_vocabulary_table_names(config) -} +def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[str, FileUploader]: + """Get a dict of table names to objects that can populate those tables from YAML files.""" + return { + name: FileUploader(metadata.tables[name]) + for name in get_vocabulary_table_names(config) + } def get_story_generator_list(story_generator_infos: Iterable[StoryGeneratorInfo], context: Mapping) -> list[Mapping]: + """Get a list of mappings describing story generators that must be run.""" return [ { "function": _call_from_context(gen_data.function_call.function_name, gen_data.function_call.argument_values, context), - "num_stories_per_pass": {{ gen_data.num_stories_per_pass }}, - "name": "{{ gen_data.function_call.function_name }}", + "num_stories_per_pass": gen_data.num_stories_per_pass, + "name": gen_data.function_call.function_name, } for gen_data in story_generator_infos ] From 00df596ff0d45c126b49438b83202b47a0e02148 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 19 Mar 2026 18:21:09 +0000 Subject: [PATCH 34/44] First test for create-data without intermediate file --- datafaker/base.py | 18 --- datafaker/create.py | 47 ++++++-- datafaker/main.py | 18 +-- datafaker/make.py | 105 +++++++++++++----- datafaker/populate.py | 160 +++++++++++++++++++++------ datafaker/utils.py | 4 +- examples/airbnb/final_ssg_example.py | 135 ---------------------- examples/airbnb/ssg_manual_edit.py | 131 ---------------------- tests/examples/empty.sql | 2 + tests/test_create.py | 47 +++++++- tests/test_dump.py | 2 - tests/test_functional.py | 61 +--------- tests/test_main.py | 3 - tests/utils.py | 4 +- 14 files changed, 309 insertions(+), 428 deletions(-) delete mode 100644 examples/airbnb/final_ssg_example.py delete mode 100644 examples/airbnb/ssg_manual_edit.py create mode 100644 tests/examples/empty.sql diff --git a/datafaker/base.py b/datafaker/base.py index 97d77992..ae2ad7b5 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ 
-22,24 +22,6 @@ ) -class TableGenerator(ABC): - """Abstract base class for table generator classes.""" - - num_rows_per_pass: int = 1 - - @abstractmethod - def __call__(self, dst_db_conn: Connection, metadata: MetaData) -> dict[str, Any]: - """Return, as a dictionary, a new row for the table that we are generating. - - The only argument, `dst_db_conn`, should be a database connection to the - database to which the data is being written. Most generators won't use it, but - some do, and thus it's required by the interface. - - The return value should be a dictionary with column names as strings for keys, - and the values being the values for the new row. - """ - - @dataclass class FileUploader: """For uploading data files.""" diff --git a/datafaker/create.py b/datafaker/create.py index 2cd0ddeb..3f2d861e 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,7 +1,6 @@ """Functions and classes to create and populate the target database.""" -import pathlib from collections import Counter -from types import ModuleType +from pathlib import Path from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect @@ -10,10 +9,19 @@ from sqlalchemy.orm import Session from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table -from datafaker.base import FileUploader, TableGenerator +from datafaker.base import FileUploader +from datafaker.make import get_generation_info +from datafaker.populate import ( + TableGenerator, + get_symbols, + get_table_generator_dict, + get_story_generator_list, + get_vocab_dict, +) from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings from datafaker.utils import ( create_db_engine_dst, + get_property, get_sync_engine, get_vocabulary_table_names, logger, @@ -92,7 +100,7 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path = pathlib.Path("."), + 
base_path: Path = Path("."), ) -> list[str]: """ Load vocabulary tables from files. @@ -140,14 +148,16 @@ def create_db_vocab( def create_db_data( sorted_tables: Sequence[Table], - df_module: ModuleType, + config: Mapping[str, Any], + src_stats_filename: Path | None, num_passes: int, metadata: MetaData, ) -> RowCounts: """Connect to a database and populate it with data.""" return create_db_data_into( sorted_tables, - df_module, + config, + src_stats_filename, num_passes, get_destination_dsn(), get_destination_schema(), @@ -158,7 +168,8 @@ def create_db_data( # pylint: disable=too-many-arguments too-many-positional-arguments def create_db_data_into( sorted_tables: Sequence[Table], - df_module: ModuleType, + config: Mapping[str, Any], + src_stats_filename: Path | None, num_passes: int, db_dsn: str, schema_name: str | None, @@ -176,17 +187,31 @@ def create_db_data_into( :param num_passes: Number of passes to perform. :param db_dsn: Connection string for the destination database. :param schema_name: Destination schema name. + :param metadata: Destination database metadata. 
""" dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name)) - + gen_info = get_generation_info(metadata, config, Path("orm.blah"), Path("config.blah"), src_stats_filename) row_counts: Counter[str] = Counter() with dst_engine.connect() as dst_conn: + context = get_symbols( + gen_info.row_generator_module_name, + gen_info.story_generator_module_name, + get_property(config, "object_instantiation", dict, {}), + gen_info.src_stats_filename, + dst_conn, + metadata, + ) for _ in range(num_passes): row_counts += populate( dst_conn, sorted_tables, - df_module.table_generator_dict, - df_module.story_generator_list, + get_table_generator_dict( + dst_conn, + gen_info.tables, + gen_info.max_unique_constraint_tries, + context, + ), + get_story_generator_list(gen_info.story_generators, context), metadata, ) dst_engine.dispose() @@ -336,7 +361,7 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values(table_generator(dst_conn, metadata)) + stmt = insert(table).values(table_generator(dst_conn)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() diff --git a/datafaker/main.py b/datafaker/main.py index ff768c61..e324f23a 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -147,15 +147,19 @@ def create_data( help="The name of the ORM yaml file", dir_okay=False, ), - df_file: str = Option( - DF_FILENAME, - help="The name of the generators file. Must be in the current working directory.", - dir_okay=False, - ), config_file: Optional[Path] = Option( CONFIG_FILENAME, help="The configuration file", ), + stats_file: Optional[Path] = Option( + None, + help=( + "Statistics file (output of make-stats); default is src-stats.yaml if the " + "config file references SRC_STATS, or None otherwise." 
+ ), + show_default=False, + dir_okay=False, + ), num_passes: int = Option(1, help="Number of passes (rows or stories) to make"), ) -> None: """Populate the schema in the target directory with synthetic data. @@ -179,11 +183,11 @@ def create_data( logger.debug("Creating data.") config = read_config_file(config_file) if config_file is not None else {} orm_metadata = load_metadata_for_output(orm_file, config) - df_module = import_file(df_file) try: row_counts = create_db_data( sorted_non_vocabulary_tables(orm_metadata, config), - df_module, + config, + stats_file, num_passes, orm_metadata, ) diff --git a/datafaker/make.py b/datafaker/make.py index a8dd58d0..e8a24142 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -66,7 +66,8 @@ class FunctionCall: """Contains the df.py content related function calls.""" function_name: str - argument_values: list[str] + args: list[str] + kwargs: dict[str, str] @dataclass @@ -83,7 +84,8 @@ class ColumnChoice: """Choose columns based on a random number in [0,1).""" function_name: str - argument_values: list[str] + args: list[str] + kwargs: dict[str, str] def make_column_choices( @@ -100,7 +102,8 @@ def make_column_choices( return [ ColumnChoice( function_name=mg["name"], - argument_values=[f"{k}={v}" for k, v in mg.get("kwargs", {}).items()], + args=mg.get("args", []), + kwargs=mg.get("kwargs", {}), ) for mg in table_config.get("missingness_generators", []) if "name" in mg @@ -168,12 +171,11 @@ def _get_function_call( if keyword_arguments is None: keyword_arguments = {} - argument_values: list[str] = [str(value) for value in positional_arguments] - argument_values += [ - f"{key}={_render_value(value)}" for key, value in keyword_arguments.items() - ] - - return FunctionCall(function_name=function_name, argument_values=argument_values) + return FunctionCall( + function_name=function_name, + args=positional_arguments, + kwargs=keyword_arguments, + ) def _get_row_generator( @@ -580,13 +582,29 @@ def make_vocabulary_tables( ) -def 
make_table_generators( # pylint: disable=too-many-locals +@dataclass +class GenerationInfo: + """Information for the generation of all data.""" + provider_imports: list[str] + orm_file_name: Path + config_file_name: Path + row_generator_module_name: str | None + story_generator_module_name: str | None + object_instantiation: dict[str, dict] + src_stats_filename: Path | None + tables: list[TableGeneratorInfo] + vocabulary_tables: list[VocabularyTableGeneratorInfo] + story_generators: list[StoryGeneratorInfo] + max_unique_constraint_tries: int | None + + +def get_generation_info( # pylint: disable=too-many-locals metadata: MetaData, config: Mapping, orm_filename: Path, config_filename: Path, src_stats_filename: Optional[Path], -) -> str: +) -> GenerationInfo: """ Create datafaker generator classes. @@ -605,10 +623,16 @@ def make_table_generators( # pylint: disable=too-many-locals :return: A string that is a valid Python module, once written to file. """ - row_generator_module_name: str = config.get("row_generators_module", None) - story_generator_module_name = config.get("story_generators_module", None) - object_instantiation: dict[str, dict] = config.get("object_instantiation", {}) - tables_config = config.get("tables", {}) + row_generator_module_name = get_property( + config, "row_generators_module", str | None, None + ) + story_generator_module_name = get_property( + config, "story_generators_module", str | None, None + ) + object_instantiation = get_property( + config, "object_instantiation", dict, {} + ) + tables_config = get_property(config, "tables", dict, {}) tables: list[TableGeneratorInfo] = [] vocabulary_tables: list[VocabularyTableGeneratorInfo] = [] @@ -637,20 +661,47 @@ def make_table_generators( # pylint: disable=too-many-locals story_generators = _get_story_generators(config) - max_unique_constraint_tries = config.get("max-unique-constraint-tries", None) + max_unique_constraint_tries = get_property( + config, "max-unique-constraint-tries", str | None, 
None + ) + return GenerationInfo( + provider_imports=PROVIDER_IMPORTS, + orm_file_name=orm_filename, + config_file_name=config_filename, + row_generator_module_name=row_generator_module_name, + story_generator_module_name=story_generator_module_name, + object_instantiation=object_instantiation, + src_stats_filename=src_stats_filename, + tables=tables, + vocabulary_tables=vocabulary_tables, + story_generators=story_generators, + max_unique_constraint_tries=max_unique_constraint_tries, + ) + + +def make_table_generators( # pylint: disable=too-many-locals + metadata: MetaData, + config: Mapping, + orm_filename: Path, + config_filename: Path, + src_stats_filename: Optional[Path], +) -> str: + gi = get_generation_info( + metadata, config, orm_filename, config_filename, src_stats_filename + ) return generate_df_content( { - "provider_imports": PROVIDER_IMPORTS, - "orm_file_name": orm_filename, - "config_file_name": config_filename, - "row_generator_module_name": row_generator_module_name, - "story_generator_module_name": story_generator_module_name, - "object_instantiation": object_instantiation, - "src_stats_filename": src_stats_filename, - "tables": tables, - "vocabulary_tables": vocabulary_tables, - "story_generators": story_generators, - "max_unique_constraint_tries": max_unique_constraint_tries, + "provider_imports": gi.provider_imports, + "orm_file_name": gi.orm_file_name, + "config_file_name": gi.config_file_name, + "row_generator_module_name": gi.row_generator_module_name, + "story_generator_module_name": gi.story_generator_module_name, + "object_instantiation": gi.object_instantiation, + "src_stats_filename": gi.src_stats_filename, + "tables": gi.tables, + "vocabulary_tables": gi.vocabulary_tables, + "story_generators": gi.story_generators, + "max_unique_constraint_tries": gi.max_unique_constraint_tries, } ) diff --git a/datafaker/populate.py b/datafaker/populate.py index e94af994..077c2edf 100644 --- a/datafaker/populate.py +++ b/datafaker/populate.py @@ -1,23 
+1,25 @@ from collections.abc import Iterable, Mapping, MutableMapping, Sequence +from pathlib import Path from mimesis import Generic from mimesis.locales import Locale import sqlalchemy from typing import Any, Callable import yaml -from datafaker.base import FileUploader, TableGenerator +from datafaker.base import FileUploader, ColumnPresence from datafaker.make import TableGeneratorInfo, StoryGeneratorInfo #TODO: move these in here! from datafaker.providers import ( BytesProvider, ColumnValueProvider, + DistributionProvider, NullProvider, SQLGroupByProvider, TimedeltaProvider, TimespanProvider, WeightedBooleanProvider, ) -from datafaker.utils import logging, get_vocabulary_table_names +from datafaker.utils import logging, get_vocabulary_table_names, import_file generic = Generic(locale=Locale.EN_GB) @@ -38,7 +40,12 @@ def _eval_structure(config: Any, context: Mapping) -> Any: :return: Object matching the structure of ``config`` with strings eval'ed. """ if isinstance(config, str): - return eval(config, locals=context) + try: + return eval(config, None, context) + except SyntaxError as exc: + raise exc + except NameError as exc: + raise exc if isinstance(config, Mapping): return { k: _eval_structure(v, context) @@ -49,7 +56,7 @@ def _eval_structure(config: Any, context: Mapping) -> Any: return config -def _get_object(class_name: Any, context: Mapping) -> Any: +def _get_object(class_name: str, context: Mapping) -> Any: """ Get an object out of the context. @@ -59,26 +66,25 @@ def _get_object(class_name: Any, context: Mapping) -> Any: :return: A value from ``context`` if there are no qualifying names, otherwise the attribute of the base object. 
""" - if not isinstance(class_name, str): - return None - if not isinstance(kwargs, Mapping): - kwargs = {} parts = class_name.split(".") if parts[0] not in context: - logging.error('No such object "%"', parts[0]) - return None + raise ValueError('No such object "%"', parts[0]) value = context[parts[0]] so_far = parts[0] for part in parts[1:]: so_far += "." + part if not hasattr(value, part): - logging.error('No such attribute "%"', so_far) - return None + raise ValueError('No such attribute "%"', so_far) value = getattr(value, part) return value -def _call_from_context(callable_name: Any, kwargs: Any, context: Mapping) -> Any: +def _call_from_context( + callable_name: str, + args: list[Any], + kwargs: dict[str, Any], + context: Mapping +) -> Any: """ Call a callable from the classes (or functions) in the context. @@ -89,25 +95,83 @@ def _call_from_context(callable_name: Any, kwargs: Any, context: Mapping) -> Any cls = _get_object(callable_name, context) if not isinstance(cls, Callable): return None - kws = _eval_structure(kwargs, context) - if kws is None: - return None - return cls(**kws) + arg_objs = [ + _eval_structure(arg, context) + for arg in args + ] + kwarg_objs = { + k: _eval_structure(v, context) + for k, v in kwargs.items() + } + return cls(*arg_objs, **kwarg_objs) + + +def get_symbols( + row_generator_module_name: str | None, + story_generator_module_name: str | None, + object_instantiation: dict[str, dict[str, Any]] | None, + src_stats_filename: str | None, + dst_db_conn: sqlalchemy.Connection, + metadata: sqlalchemy.MetaData, +) -> dict[str, Any]: + """Get the symbols that may be referred to by various configuration settings.""" + symbols = { + "dst_db_conn": dst_db_conn, + "metadata": metadata, + "generic": generic, + "numeric": generic.numeric, + "person": generic.person, + "dist_gen": DistributionProvider(), + "column_presence": ColumnPresence(), + } + _get_symbol_import(symbols, row_generator_module_name) + _get_symbol_import(symbols, 
story_generator_module_name) + if object_instantiation: + _get_symbols_instantiation(symbols, object_instantiation) + if src_stats_filename: + with Path(src_stats_filename).open(encoding="utf-8") as fh: + symbols["SRC_STATS"] = yaml.load(fh, yaml.SafeLoader) + return symbols -def _get_src_stats(src_stats_filename: str) -> Any: +def _get_symbol_import(symbols: dict[str, Any], module_name: str | None) -> None: """ - Get the SRC_STATS object + Load a module and add it as a symbol. + + :param symbols: Dict to add the module to. + :param module_name: if None, nothing will be added to ``symbols``. + Otherwise the ``module_name`` module will be loaded and added as + ``symbols[module_name]``. """ - with open(src_stats_filename, "r", encoding="utf-8") as f: - return yaml.load(f, yaml.SafeLoader) + if module_name is None: + return + symbols[module_name] = import_file(module_name + ".py") + + +def _get_symbols_instantiation(symbols: dict[str, Any], objs: dict[str, Any]) -> None: + """ + Instantiate objects and add them to the ``symbols`` dictionary. + + :param symbols: Dict to add the new objects to; also the context for the + instantiations. + :param objs: Dict of names to instantiation configurations. The names are + the keys that will be added to ``symbols``, the values are each callable + named by ``objs[name]["class"]`` with the arguments provided by + ``objs[name]["kwargs"]`` (which is a dict of argument names to a + Python string of the value to pass to that argument, such as ``'0'`` for + the number zero or ``"hello"`` for the string "hello"). 
+ """ + for name, inst in objs.items(): + clbl = inst.get("class", None) + kwargs = inst.get("kwargs", {}) + if isinstance(clbl, str) and isinstance(kwargs, dict): + symbols[name] = _call_from_context(clbl, kwargs, symbols) class TableGenerator: def __init__( self, - rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, table_data: TableGeneratorInfo, max_unique_constraint_tries: int | None, @@ -123,10 +187,8 @@ def __init__( this could cause an infinite loop if there are no solutions, or very long execution if there are few solutions with many constraints. """ - self.num_rows_per_pass = rows_per_pass self.table_data = table_data self.max_unique_constraint_tries = max_unique_constraint_tries - self.db_conn = dst_db_conn self.existing_constraint_hashes: MutableMapping[str, set[int]] = {} self.context: Mapping = {} for constraint in table_data.unique_constraints: @@ -137,17 +199,27 @@ def __init__( for result in query_result ]) + @property + def num_rows_per_pass(self): + """Get the number of rows this generator should produce relative to all the rest.""" + return self.table_data.rows_per_pass + + @property + def name(self): + """Get the name of the table whose rows we are generating.""" + return self.table_data.table_name + def set_context(self, context: Mapping) -> None: """Sets all the Python symbols that must be known to the configuration.""" self.context = context - def __call__(self): + def __call__(self, db_conn: sqlalchemy.Connection): """Generate some rows of the relevant table in the database.""" result: dict[str, Any] = {} columns_to_generate = set(self.table_data.nonnull_columns) # Which missingness patterns do we want? 
for choice in self.table_data.column_choices: - cols = _call_from_context(choice.function_name, choice.argument_values, self.context) + cols = _call_from_context(choice.function_name, choice.args, choice.kwargs, self.context) columns_to_generate.update(cols) max_tries = self.max_unique_constraint_tries @@ -158,9 +230,17 @@ def __call__(self): max_tries -= 1 for row_gen in self.table_data.row_gens: if set(row_gen.variable_names) & columns_to_generate: - values = _call_from_context(row_gen.function_call.function_name, row_gen.function_call.argument_values, self.context) - for index, variable_name in enumerate(row_gen.variable_names): - result[variable_name] = values[index] + values = _call_from_context( + row_gen.function_call.function_name, + row_gen.function_call.args, + row_gen.function_call.kwargs, + self.context, + ) + if len(row_gen.variable_names) == 1: + result[row_gen.variable_names[0]] = values + else: + for index, variable_name in enumerate(row_gen.variable_names): + result[variable_name] = values[index] columns_to_generate = set() for constraint in self.table_data.unique_constraints: cf_hash = hash(tuple( @@ -175,15 +255,33 @@ def __call__(self): self.existing_constraint_hashes.add(cf_hash) return result + +def _make_table_generator( + dst_db_conn: sqlalchemy.Connection, + table_data: TableGeneratorInfo, + max_unique_constraint_tries: int | None, + context: Mapping, +): + """Make a ``TableGenerator`` with context attached.""" + gen = TableGenerator(dst_db_conn, table_data, max_unique_constraint_tries) + gen.set_context(context) + return gen + + def get_table_generator_dict( - rows_per_pass: int, dst_db_conn: sqlalchemy.Connection, tables_data: Iterable[TableGeneratorInfo], max_unique_constraint_tries: int | None, + context: Mapping, ): """Get a dict of table names to row generators that generate rows for that table.""" return { - table_data.table_name: TableGenerator(rows_per_pass, dst_db_conn, table_data, max_unique_constraint_tries) + 
table_data.table_name: _make_table_generator( + dst_db_conn, + table_data, + max_unique_constraint_tries, + context, + ) for table_data in tables_data } diff --git a/datafaker/utils.py b/datafaker/utils.py index 9a1ffbcb..25504a17 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -404,7 +404,7 @@ def get_flag(maybe_dict: Any, key: Any) -> bool: return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False) -def get_property(maybe_dict: Any, key: Any, required_type: Type[T], default: T) -> T: +def get_property(maybe_dict: Any, key: Any, required_type: type[T], default: T) -> T: """ Get a specific property from a dict or a default if that does not exist. @@ -797,7 +797,7 @@ def generators_require_stats(config: Mapping) -> bool: :param config: ``config.yaml`` object. :return: True if any of the arguments for any of the generators - reference ``SRC_STATS``. + reference ``SRC_STATS``. """ ois = { f"object_instantiation.{k}": call diff --git a/examples/airbnb/final_ssg_example.py b/examples/airbnb/final_ssg_example.py deleted file mode 100644 index 1212bbc5..00000000 --- a/examples/airbnb/final_ssg_example.py +++ /dev/null @@ -1,135 +0,0 @@ -"""This file was auto-generated by sqlsynthgen but can be edited manually.""" -from mimesis import Generic -from mimesis.locales import Locale -from sqlsynthgen.base import FileUploader, TableGenerator -from sqlsynthgen.unique_generator import UniqueGenerator - -generic = Generic(locale=Locale.EN_GB) - -from sqlsynthgen.providers import BytesProvider - -generic.add_provider(BytesProvider) -from sqlsynthgen.providers import ColumnValueProvider - -generic.add_provider(ColumnValueProvider) -from sqlsynthgen.providers import NullProvider - -generic.add_provider(NullProvider) -from sqlsynthgen.providers import SQLGroupByProvider - -generic.add_provider(SQLGroupByProvider) -from sqlsynthgen.providers import TimedeltaProvider - -generic.add_provider(TimedeltaProvider) -from sqlsynthgen.providers import TimespanProvider - 
-generic.add_provider(TimespanProvider) -from sqlsynthgen.providers import WeightedBooleanProvider - -generic.add_provider(WeightedBooleanProvider) - -import orm -import airbnb_generators -import airbnb_generators - -import yaml - -with open("src-stats.yaml", "r", encoding="utf-8") as f: - SRC_STATS = yaml.unsafe_load(f) - -countries_vocab = FileUploader(orm.Countries.__table__) - - -class age_gender_bktsGenerator(TableGenerator): - num_rows_per_pass = 1 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["gender"] = generic.person.password() - result["age_bucket"] = generic.person.password() - result["country_destination"] = generic.column_value_provider.column_value( - dst_db_conn, orm.Countries, "country_destination" - ) - result["population_in_thousands"] = generic.numeric.integer_number() - result["year"] = generic.numeric.integer_number() - return result - - -class usersGenerator(TableGenerator): - num_rows_per_pass = 0 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["age"] = airbnb_generators.user_age_provider( - query_results=SRC_STATS["age_stats"] - ) - result["id"] = generic.person.password() - ( - result["date_account_created"], - result["date_first_booking"], - ) = airbnb_generators.user_dates_provider(generic=generic) - result["timestamp_first_active"] = generic.datetime.datetime() - result["gender"] = generic.text.color() - result["signup_method"] = generic.text.color() - result["signup_flow"] = generic.numeric.integer_number() - result["language"] = generic.text.color() - result["affiliate_channel"] = generic.text.color() - result["affiliate_provider"] = generic.text.color() - result["first_affiliate_tracked"] = generic.text.color() - result["signup_app"] = generic.text.color() - result["first_device_type"] = generic.text.color() - result["first_browser"] = generic.text.color() - result["country_destination"] = generic.column_value_provider.column_value( - dst_db_conn, 
orm.Countries, "country_destination" - ) - return result - - -class sessionsGenerator(TableGenerator): - num_rows_per_pass = 0 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["secs_elapsed"] = generic.numeric.integer_number(start=0, end=3600) - result["action"] = generic.choice(items=["show", "index", "personalize"]) - result["user_id"] = generic.column_value_provider.column_value( - dst_db_conn, orm.Users, "id" - ) - result["action_type"] = generic.text.color() - result["action_detail"] = generic.text.color() - result["device_type"] = generic.text.color() - return result - - -table_generator_dict = { - "age_gender_bkts": age_gender_bktsGenerator(), - "users": usersGenerator(), - "sessions": sessionsGenerator(), -} - - -vocab_dict = { - "countries": countries_vocab, -} - - -def run_airbnb_generators_sessions_story(dst_db_conn): - return airbnb_generators.sessions_story() - - -story_generator_list = [ - { - "function": run_airbnb_generators_sessions_story, - "num_stories_per_pass": 30, - "name": "airbnb_generators.sessions_story", - }, -] diff --git a/examples/airbnb/ssg_manual_edit.py b/examples/airbnb/ssg_manual_edit.py deleted file mode 100644 index dafbb4c8..00000000 --- a/examples/airbnb/ssg_manual_edit.py +++ /dev/null @@ -1,131 +0,0 @@ -"""This file was auto-generated by sqlsynthgen but can be edited manually.""" -from mimesis import Generic -from mimesis.locales import Locale -from sqlsynthgen.base import FileUploader, TableGenerator -from sqlsynthgen.unique_generator import UniqueGenerator - -generic = Generic(locale=Locale.EN_GB) - -from sqlsynthgen.providers import BytesProvider - -generic.add_provider(BytesProvider) -from sqlsynthgen.providers import ColumnValueProvider - -generic.add_provider(ColumnValueProvider) -from sqlsynthgen.providers import NullProvider - -generic.add_provider(NullProvider) -from sqlsynthgen.providers import SQLGroupByProvider - -generic.add_provider(SQLGroupByProvider) -from 
sqlsynthgen.providers import TimedeltaProvider - -generic.add_provider(TimedeltaProvider) -from sqlsynthgen.providers import TimespanProvider - -generic.add_provider(TimespanProvider) -from sqlsynthgen.providers import WeightedBooleanProvider - -generic.add_provider(WeightedBooleanProvider) - -import orm - - -class countriesGenerator(TableGenerator): - num_rows_per_pass = 1 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["country_destination"] = generic.person.password() # manual edit - result["lat_destination"] = generic.numeric.float_number() - result["lng_destination"] = generic.numeric.float_number() - result["distance_km"] = generic.numeric.float_number() - result["destination_km2"] = generic.numeric.integer_number() - result["destination_language"] = generic.text.color() - result["language_levenshtein_distance"] = generic.numeric.float_number() - return result - - -class age_gender_bktsGenerator(TableGenerator): - num_rows_per_pass = 1 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - # start manual edit - result["age_bucket"] = generic.person.password() - result["country_destination"] = ColumnValueProvider().column_value( - dst_db_conn, orm.Countries, "country_destination" - ) - result["gender"] = generic.person.password() - # end manual edit - result["population_in_thousands"] = generic.numeric.integer_number() - result["year"] = generic.numeric.integer_number() - return result - - -class usersGenerator(TableGenerator): - num_rows_per_pass = 1 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["id"] = generic.person.password() # manual edit - result["date_account_created"] = generic.datetime.date() - result["timestamp_first_active"] = generic.datetime.datetime() - result["date_first_booking"] = generic.datetime.date() - result["gender"] = generic.text.color() - result["age"] = generic.numeric.integer_number() - result["signup_method"] = 
generic.text.color() - result["signup_flow"] = generic.numeric.integer_number() - result["language"] = generic.text.color() - result["affiliate_channel"] = generic.text.color() - result["affiliate_provider"] = generic.text.color() - result["first_affiliate_tracked"] = generic.text.color() - result["signup_app"] = generic.text.color() - result["first_device_type"] = generic.text.color() - result["first_browser"] = generic.text.color() - result["country_destination"] = generic.column_value_provider.column_value( - dst_db_conn, orm.Countries, "country_destination" - ) - return result - - -class sessionsGenerator(TableGenerator): - num_rows_per_pass = 1 - - def __init__(self): - pass - - def __call__(self, dst_db_conn): - result = {} - result["user_id"] = generic.column_value_provider.column_value( - dst_db_conn, orm.Users, "id" - ) - result["action"] = generic.text.color() - result["action_type"] = generic.text.color() - result["action_detail"] = generic.text.color() - result["device_type"] = generic.text.color() - result["secs_elapsed"] = generic.numeric.float_number() - return result - - -table_generator_dict = { - "countries": countriesGenerator(), - "age_gender_bkts": age_gender_bktsGenerator(), - "users": usersGenerator(), - "sessions": sessionsGenerator(), -} - - -vocab_dict = {} - - -story_generator_list = [] diff --git a/tests/examples/empty.sql b/tests/examples/empty.sql new file mode 100644 index 00000000..0343e1b4 --- /dev/null +++ b/tests/examples/empty.sql @@ -0,0 +1,2 @@ +CREATE DATABASE empty WITH TEMPLATE template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8'; +ALTER DATABASE empty OWNER TO postgres; diff --git a/tests/test_create.py b/tests/test_create.py index 3ecdce39..9bd014e7 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -13,15 +13,17 @@ from sqlalchemy import Connection, Engine, select from sqlalchemy.schema import MetaData, Table -from datafaker.base import TableGenerator +from datafaker.populate import TableGenerator from 
datafaker.create import ( create_db_data_into, create_db_tables, + create_db_tables_into, create_db_vocab, populate, ) -from datafaker.serialize_metadata import metadata_to_dict -from tests.utils import DatafakerTestCase, GeneratesDBTestCase +from datafaker.serialize_metadata import metadata_to_dict, dict_to_metadata +from datafaker.utils import sorted_non_vocabulary_tables +from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase class TestCreate(GeneratesDBTestCase): @@ -306,6 +308,7 @@ def __call__( "duckdb:///:memory:data", None, MagicMock(), + MagicMock(), ) assert mock_populate.side_effect.called @@ -338,3 +341,41 @@ def __call__(self, connection: Connection, base_path: Path) -> None: base_path=Path("base"), ) assert file_uploader.return_value.load.called + + +class CreateDataTestCase(RequiresDBTestCase): + """Tests for create-data.""" + dump_file_path = "empty.sql" + database_name = "empty" + schema_name = "public" + + def test_create_data_minimal(self) -> None: + """Test creating one table with one PK column.""" + config = {} + orm = { + "tables": { + "one": { + "columns": { + "id": { + "primary": True, + "type": "INTEGER", + } + } + } + } + } + metadata = dict_to_metadata(orm, config) + create_db_tables_into(metadata, self.dsn, self.schema_name) + row_counts = create_db_data_into( + sorted_non_vocabulary_tables(metadata, config), + config, + None, + 4, + self.dsn, + self.schema_name, + metadata, + ) + with self.sync_engine.connect() as connection: + stmt = select(metadata.tables["one"]) + rows = connection.execute(stmt).fetchall() + self.assertListEqual(rows, [(1,), (2,), (3,), (4,)]) diff --git a/tests/test_dump.py b/tests/test_dump.py index d6812185..d1958272 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -175,8 +175,6 @@ def test_end_to_end_parquet(self) -> None: # Generate the fake data result = runner.invoke(app, ["create-tables"]) self.assertSuccess(result) - result = runner.invoke(app, ["create-generators"]) - 
self.assertSuccess(result) num_passes = 70 result = runner.invoke(app, ["create-data", "--num-passes", str(num_passes)]) self.assertSuccess(result) diff --git a/tests/test_functional.py b/tests/test_functional.py index eac67081..fcef8618 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -99,32 +99,6 @@ def test_workflow_minimal_args(self) -> None: ) self.assert_silent_success(completed_process) - completed_process = self.invoke( - "create-generators", - "--force", - "--stats-file=src-stats.yaml", - ) - self.assertNoException(completed_process) - self.assertEqual( - { - ( - "Unsupported SQLAlchemy type CIDR for column " - "column_with_unusual_type of table strange_type_table. " - "Setting this column to NULL always, you may want to " - "configure a row generator for it instead." - ), - ( - "Unsupported SQLAlchemy type BIT for column " - "column_with_unusual_type_and_length of table " - "strange_type_table. Setting this column to NULL always, " - "you may want to configure a row generator for it instead." 
- ), - }, - set(completed_process.stderr.split("\n")) - {""}, - ) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) - completed_process = self.invoke( "create-tables", ) @@ -141,7 +115,9 @@ def test_workflow_minimal_args(self) -> None: ) self.assert_silent_success(completed_process) - completed_process = self.invoke("create-data") + completed_process = self.invoke( + "create-data", + ) self.assertNoException(completed_process) self.assertEqual("", completed_process.stderr) self.assertSuccess(completed_process) @@ -249,33 +225,6 @@ def test_workflow_maximal_args(self) -> None: set(completed_process.stdout.split("\n")) - {""}, ) - completed_process = self.invoke( - "--verbose", - "create-generators", - f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", - f"--config-file={self.config_file_path}", - f"--stats-file={self.stats_file_path}", - "--force", - ) - self.assertEqual( - "Unsupported SQLAlchemy type CIDR " - "for column column_with_unusual_type of table strange_type_table. " - "Setting this column to NULL always, " - "you may want to configure a row generator for it instead.\n" - "Unsupported SQLAlchemy type BIT " - "for column column_with_unusual_type_and_length of table " - "strange_type_table. 
Setting this column to NULL always, " - "you may want to configure a row generator for it instead.\n", - completed_process.stderr, - ) - self.assertSuccess(completed_process) - self.assertEqual( - f"Making {self.alt_datafaker_file_path}.\n" - f"{self.alt_datafaker_file_path} created.\n", - completed_process.stdout, - ) - completed_process = self.invoke( "--verbose", "create-tables", @@ -324,7 +273,7 @@ def test_workflow_maximal_args(self) -> None: "--verbose", "create-data", f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", + f"--stats-file={self.stats_file_path}", f"--config-file={self.config_file_path}", "--num-passes=2", ) @@ -529,7 +478,7 @@ def test_unique_constraint_fail(self) -> None: "create-data", f"--config-file={self.config_file_path}", f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", + f"--stats-file={self.stats_file_path}", "--num-passes=1", ) self.assertEqual("", completed_process.stderr) diff --git a/tests/test_main.py b/tests/test_main.py index 50d86027..a3f0d5da 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -191,7 +191,6 @@ def test_create_tables( @patch("datafaker.main.sorted_non_vocabulary_tables") @patch("datafaker.main.logger") - @patch("datafaker.main.import_file") @patch("datafaker.main.create_db_data") @patch("datafaker.main.load_metadata_for_output") # pylint: disable=too-many-arguments too-many-positional-arguments @@ -199,7 +198,6 @@ def test_create_data( self, mock_load_metadata: MagicMock, mock_create: MagicMock, - mock_import: MagicMock, mock_logger: MagicMock, mock_tables: MagicMock, ) -> None: @@ -215,7 +213,6 @@ def test_create_data( ], catch_exceptions=False, ) - self.assertListEqual([call("df.py")], mock_import.call_args_list) mock_create.assert_called_once_with( mock_tables.return_value, diff --git a/tests/utils.py b/tests/utils.py index 54155337..544011f1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -486,10 +486,10 @@ def 
create_tables(self) -> None: def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff - datafaker_module = import_file(self.generators_file_path) create_db_data_into( sorted_non_vocabulary_tables(self.metadata, config), - datafaker_module, + config, + Path(self.stats_file_path), num_passes, self.dst_dsn, self.dst_schema_name, From ba9bf59a72b82a019eafbedcf20569672a98a52d Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 20 Mar 2026 20:07:36 +0000 Subject: [PATCH 35/44] test_workflow_minimal_args passes --- datafaker/create.py | 139 +++++++++------- datafaker/main.py | 33 +--- datafaker/make.py | 40 +---- datafaker/populate.py | 57 +++---- tests/test_create.py | 4 +- tests/test_functional.py | 15 -- ...test_interactive_generators_partitioned.py | 4 - tests/test_main.py | 157 +----------------- tests/utils.py | 16 +- 9 files changed, 121 insertions(+), 344 deletions(-) diff --git a/datafaker/create.py b/datafaker/create.py index 3f2d861e..d32c15db 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -2,6 +2,7 @@ from collections import Counter from pathlib import Path from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple +import yaml from sqlalchemy import Connection, insert, inspect from sqlalchemy.exc import IntegrityError @@ -12,10 +13,11 @@ from datafaker.base import FileUploader from datafaker.make import get_generation_info from datafaker.populate import ( + StoryGeneratorInfo, TableGenerator, + call_function, get_symbols, get_table_generator_dict, - get_story_generator_list, get_vocab_dict, ) from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings @@ -154,10 +156,15 @@ def create_db_data( metadata: MetaData, ) -> RowCounts: """Connect to a database and populate it with data.""" + if src_stats_filename: + with src_stats_filename.open(encoding="utf-8") as fh: + src_stats = yaml.load(fh, yaml.SafeLoader) 
+ else: + src_stats = None return create_db_data_into( sorted_tables, config, - src_stats_filename, + src_stats, num_passes, get_destination_dsn(), get_destination_schema(), @@ -169,7 +176,7 @@ def create_db_data( def create_db_data_into( sorted_tables: Sequence[Table], config: Mapping[str, Any], - src_stats_filename: Path | None, + src_stats: dict[str, dict[str, Any]] | None, num_passes: int, db_dsn: str, schema_name: str | None, @@ -190,17 +197,17 @@ def create_db_data_into( :param metadata: Destination database metadata. """ dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name)) - gen_info = get_generation_info(metadata, config, Path("orm.blah"), Path("config.blah"), src_stats_filename) + gen_info = get_generation_info(metadata, config) + context = get_symbols( + gen_info.row_generator_module_name, + gen_info.story_generator_module_name, + get_property(config, "object_instantiation", dict, {}), + src_stats, + metadata, + ) row_counts: Counter[str] = Counter() with dst_engine.connect() as dst_conn: - context = get_symbols( - gen_info.row_generator_module_name, - gen_info.story_generator_module_name, - get_property(config, "object_instantiation", dict, {}), - gen_info.src_stats_filename, - dst_conn, - metadata, - ) + context["dst_db_conn"] = dst_conn for _ in range(num_passes): row_counts += populate( dst_conn, @@ -211,8 +218,8 @@ def create_db_data_into( gen_info.max_unique_constraint_tries, context, ), - get_story_generator_list(gen_info.story_generators, context), - metadata, + gen_info.story_generators, + context, ) dst_engine.dispose() return row_counts @@ -224,24 +231,51 @@ class StoryIterator: def __init__( self, - stories: Iterable[tuple[str, Story]], + stories: Iterable[StoryGeneratorInfo], table_dict: Mapping[str, Table], table_generator_dict: Mapping[str, TableGenerator], dst_conn: Connection, + context: Mapping, ): """Initialise a Story Iterator.""" - self._stories: Iterator[tuple[str, Story]] = iter(stories) + 
self._story_infos: Iterator[StoryGeneratorInfo] = iter(stories) self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict self._dst_conn: Connection = dst_conn - self._table_name: str | None + self._table_name: str | None = None self._final_values: dict[str, Any] | None = None + # Number of times the current story should be run + self._story_counts = 1 + self._story_function_call = None + self._context = context + self._story = iter([]) + self.next() + + def _get_next_story(self) -> None: + """ + Iterate to the next ``_story_infos``. + + :return: False if there are no more. + """ try: - name, self._story = next(self._stories) - logger.info("Generating data for story '%s'", name) - self._table_name, self._provided_values = next(self._story) + sgi = next(self._story_infos) + self._story_counts = sgi.num_stories_per_pass + self._story_function_call = sgi.function_call + logger.info("Generating data for story '%s'", sgi.function_call.function_name) + self._story = call_function(sgi.function_call, self._context) + self._final_values = None except StopIteration: self._table_name = None + return False + return True + + def _get_values(self) -> None: + if self._final_values is None: + self._table_name, self._provided_values = next(self._story) + else: + self._table_name, self._provided_values = self._story.send( + self._final_values + ) def is_ended(self) -> bool: """ @@ -249,7 +283,8 @@ def is_ended(self) -> bool: If so, insert() can be called. """ - return self._table_name is None + return self._story_counts == -1 + def has_table(self, table_name: str) -> bool: """Check if we have a row for table ``table_name``.""" @@ -264,7 +299,7 @@ def table_name(self) -> str | None: """ return self._table_name - def insert(self, metadata: MetaData) -> None: + def insert(self) -> None: """ Put the row in the table. 
@@ -276,7 +311,7 @@ def insert(self, metadata: MetaData) -> None: table = self._table_dict[self._table_name] if table.name in self._table_generator_dict: table_generator = self._table_generator_dict[table.name] - default_values = table_generator(self._dst_conn, metadata) + default_values = table_generator(self._dst_conn) else: default_values = {} insert_values = {**default_values, **self._provided_values} @@ -300,20 +335,16 @@ def next(self) -> None: """Advance to the next row.""" while True: try: - if self._final_values is None: - self._table_name, self._provided_values = next(self._story) - return - self._table_name, self._provided_values = self._story.send( - self._final_values - ) + self._get_values() return - except StopIteration: - try: - name, self._story = next(self._stories) - logger.info("Generating data for story '%s'", name) - self._final_values = None - except StopIteration: - self._table_name = None + except StopIteration as exc: + self._final_values = None + self._story_counts -= 1 + if 0 < self._story_counts: + # Reinitialize the same story again + self._story = call_function(self._story_function_call, self._context) + elif not self._get_next_story(): + self._story_counts = -1 return @@ -321,8 +352,8 @@ def populate( dst_conn: Connection, tables: Sequence[Table], table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], - metadata: MetaData, + story_generator_infos: Sequence[StoryGeneratorInfo], + context: Mapping, ) -> RowCounts: """Populate a database schema with synthetic data.""" row_counts: Counter[str] = Counter() @@ -330,24 +361,20 @@ def populate( # Generate stories # Each story generator returns a python generator (an unfortunate naming clash with # what we call generators). Iterating over it yields individual rows for the - # database. First, collect all of the python generators into a single list. 
- stories: list[tuple[str, Story]] = sum( - [ - [ - (sg["name"], sg["function"](dst_conn)) - for _ in range(sg["num_stories_per_pass"]) - ] - for sg in story_generator_list - ], - [], + # database. + story_iterator = StoryIterator( + story_generator_infos, + table_dict, + table_generator_dict, + dst_conn, + context, ) - story_iterator = StoryIterator(stories, table_dict, table_generator_dict, dst_conn) # Generate individual rows, table by table. for table in tables: # Do we have a story row to enter into this table? if story_iterator.has_table(table.name): - story_iterator.insert(metadata) + story_iterator.insert() row_counts[table.name] = row_counts.get(table.name, 0) + 1 story_iterator.next() if table.name not in table_generator_dict: @@ -358,20 +385,20 @@ def populate( continue logger.debug("Generating data for table '%s'", table.name) # Run all the inserts for one table in a transaction - try: - with dst_conn.begin(): + with dst_conn.begin(): + try: for _ in range(table_generator.num_rows_per_pass): stmt = insert(table).values(table_generator(dst_conn)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() - except: - dst_conn.rollback() - raise + except: + dst_conn.rollback() + raise # Insert any remaining stories while not story_iterator.is_ended(): - story_iterator.insert(metadata) + story_iterator.insert() t = story_iterator.table_name() if t is None: raise AssertionError( diff --git a/datafaker/main.py b/datafaker/main.py index e324f23a..03b61656 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -31,7 +31,6 @@ from datafaker.interactive.base import DbCmd from datafaker.make import ( make_src_stats, - make_table_generators, make_tables_file, make_vocabulary_tables, ) @@ -182,6 +181,8 @@ def create_data( """ logger.debug("Creating data.") config = read_config_file(config_file) if config_file is not None else {} + if stats_file is None and generators_require_stats(config): + stats_file = 
Path(STATS_FILENAME) orm_metadata = load_metadata_for_output(orm_file, config) try: row_counts = create_db_data( @@ -294,34 +295,8 @@ def create_generators( False, "--force", "-f", help="Overwrite any existing Python generators file." ), ) -> None: - """Make a datafaker file of generator classes. - - This CLI command takes an object relation model output by sqlcodegen and - returns a set of synthetic data generators for each attribute - - Example: - $ datafaker create-generators - """ - logger.debug("Making %s.", df_file) - - if not force: - _check_file_non_existence(df_file) - - generator_config = read_config_file(config_file) if config_file is not None else {} - if stats_file is None and generators_require_stats(generator_config): - stats_file = Path(STATS_FILENAME) - orm_metadata = load_metadata_for_output(orm_file, generator_config) - result: str = make_table_generators( - orm_metadata, - generator_config, - orm_file, - config_file, - stats_file, - ) - - df_file.write_text(result, encoding="utf-8") - - logger.debug("%s created.", df_file) + """Obsolete command.""" + logger.error("This command is deprecated; it does nothing.") @app.command() diff --git a/datafaker/make.py b/datafaker/make.py index e8a24142..6f9d4c16 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -63,7 +63,7 @@ class VocabularyTableGeneratorInfo: @dataclass class FunctionCall: - """Contains the df.py content related function calls.""" + """Which function to call with what.""" function_name: str args: list[str] @@ -586,24 +586,18 @@ def make_vocabulary_tables( class GenerationInfo: """Information for the generation of all data.""" provider_imports: list[str] - orm_file_name: Path - config_file_name: Path row_generator_module_name: str | None story_generator_module_name: str | None object_instantiation: dict[str, dict] - src_stats_filename: Path | None tables: list[TableGeneratorInfo] vocabulary_tables: list[VocabularyTableGeneratorInfo] story_generators: list[StoryGeneratorInfo] 
max_unique_constraint_tries: int | None -def get_generation_info( # pylint: disable=too-many-locals +def get_generation_info( metadata: MetaData, config: Mapping, - orm_filename: Path, - config_filename: Path, - src_stats_filename: Optional[Path], ) -> GenerationInfo: """ Create datafaker generator classes. @@ -666,12 +660,9 @@ def get_generation_info( # pylint: disable=too-many-locals ) return GenerationInfo( provider_imports=PROVIDER_IMPORTS, - orm_file_name=orm_filename, - config_file_name=config_filename, row_generator_module_name=row_generator_module_name, story_generator_module_name=story_generator_module_name, object_instantiation=object_instantiation, - src_stats_filename=src_stats_filename, tables=tables, vocabulary_tables=vocabulary_tables, story_generators=story_generators, @@ -679,33 +670,6 @@ def get_generation_info( # pylint: disable=too-many-locals ) -def make_table_generators( # pylint: disable=too-many-locals - metadata: MetaData, - config: Mapping, - orm_filename: Path, - config_filename: Path, - src_stats_filename: Optional[Path], -) -> str: - gi = get_generation_info( - metadata, config, orm_filename, config_filename, src_stats_filename - ) - return generate_df_content( - { - "provider_imports": gi.provider_imports, - "orm_file_name": gi.orm_file_name, - "config_file_name": gi.config_file_name, - "row_generator_module_name": gi.row_generator_module_name, - "story_generator_module_name": gi.story_generator_module_name, - "object_instantiation": gi.object_instantiation, - "src_stats_filename": gi.src_stats_filename, - "tables": gi.tables, - "vocabulary_tables": gi.vocabulary_tables, - "story_generators": gi.story_generators, - "max_unique_constraint_tries": gi.max_unique_constraint_tries, - } - ) - - def generate_df_content(template_context: Mapping[str, Any]) -> str: """Generate the content of the df.py file as a string.""" environment: Environment = Environment( diff --git a/datafaker/populate.py b/datafaker/populate.py index 077c2edf..42861a42 
100644 --- a/datafaker/populate.py +++ b/datafaker/populate.py @@ -4,10 +4,9 @@ from mimesis.locales import Locale import sqlalchemy from typing import Any, Callable -import yaml from datafaker.base import FileUploader, ColumnPresence -from datafaker.make import TableGeneratorInfo, StoryGeneratorInfo #TODO: move these in here! +from datafaker.make import FunctionCall, TableGeneratorInfo, StoryGeneratorInfo from datafaker.providers import ( BytesProvider, @@ -106,17 +105,24 @@ def _call_from_context( return cls(*arg_objs, **kwarg_objs) +def call_function(fn: FunctionCall, context: Mapping) -> Any: + return _call_from_context( + fn.function_name, + fn.args, + fn.kwargs, + context, + ) + + def get_symbols( row_generator_module_name: str | None, story_generator_module_name: str | None, object_instantiation: dict[str, dict[str, Any]] | None, - src_stats_filename: str | None, - dst_db_conn: sqlalchemy.Connection, + src_stats: dict[str, dict[str, Any]] | None, metadata: sqlalchemy.MetaData, ) -> dict[str, Any]: """Get the symbols that may be referred to by various configuration settings.""" symbols = { - "dst_db_conn": dst_db_conn, "metadata": metadata, "generic": generic, "numeric": generic.numeric, @@ -128,9 +134,8 @@ def get_symbols( _get_symbol_import(symbols, story_generator_module_name) if object_instantiation: _get_symbols_instantiation(symbols, object_instantiation) - if src_stats_filename: - with Path(src_stats_filename).open(encoding="utf-8") as fh: - symbols["SRC_STATS"] = yaml.load(fh, yaml.SafeLoader) + if src_stats is not None: + symbols["SRC_STATS"] = src_stats return symbols @@ -165,7 +170,7 @@ def _get_symbols_instantiation(symbols: dict[str, Any], objs: dict[str, Any]) -> clbl = inst.get("class", None) kwargs = inst.get("kwargs", {}) if isinstance(clbl, str) and isinstance(kwargs, dict): - symbols[name] = _call_from_context(clbl, kwargs, symbols) + symbols[name] = _call_from_context(clbl, [], kwargs, symbols) class TableGenerator: @@ -191,13 +196,14 @@ 
def __init__( self.max_unique_constraint_tries = max_unique_constraint_tries self.existing_constraint_hashes: MutableMapping[str, set[int]] = {} self.context: Mapping = {} - for constraint in table_data.unique_constraints: - expr = sqlalchemy.select(constraint.columns) - query_result = dst_db_conn.execute(expr).fetchall() - self.existing_constraint_hashes[constraint.name] = set([ - hash(tuple(result)) - for result in query_result - ]) + with dst_db_conn.begin(): + for constraint in table_data.unique_constraints: + expr = sqlalchemy.select(constraint.columns) + query_result = dst_db_conn.execute(expr).fetchall() + self.existing_constraint_hashes[constraint.name] = set([ + hash(tuple(result)) + for result in query_result + ]) @property def num_rows_per_pass(self): @@ -230,10 +236,8 @@ def __call__(self, db_conn: sqlalchemy.Connection): max_tries -= 1 for row_gen in self.table_data.row_gens: if set(row_gen.variable_names) & columns_to_generate: - values = _call_from_context( - row_gen.function_call.function_name, - row_gen.function_call.args, - row_gen.function_call.kwargs, + values = call_function( + row_gen.function_call, self.context, ) if len(row_gen.variable_names) == 1: @@ -252,7 +256,7 @@ def __call__(self, db_conn: sqlalchemy.Connection): cf_hash = hash(tuple( result[col.name] for col in constraint.columns )) - self.existing_constraint_hashes.add(cf_hash) + self.existing_constraint_hashes[constraint.name].add(cf_hash) return result @@ -292,14 +296,3 @@ def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[st name: FileUploader(metadata.tables[name]) for name in get_vocabulary_table_names(config) } - -def get_story_generator_list(story_generator_infos: Iterable[StoryGeneratorInfo], context: Mapping) -> list[Mapping]: - """Get a list of mappings describing story generators that must be run.""" - return [ - { - "function": _call_from_context(gen_data.function_call.function_name, gen_data.function_call.argument_values, context), - 
"num_stories_per_pass": gen_data.num_stories_per_pass, - "name": gen_data.function_call.function_name, - } - for gen_data in story_generator_infos - ] diff --git a/tests/test_create.py b/tests/test_create.py index 9bd014e7..ff1df950 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -27,7 +27,7 @@ class TestCreate(GeneratesDBTestCase): - """Test the make_table_generators function.""" + """Test that we can create data.""" dump_file_path = "instrument.sql" database_name = "instrument" @@ -67,7 +67,7 @@ def test_create_vocab(self) -> None: self.assertEqual(rows[2].given_name, "Mus") self.assertEqual(rows[2].family_name, "Al-Said") - def test_make_table_generators(self) -> None: + def test_column_defaults_in_stories(self) -> None: """Test that we can handle column defaults in stories.""" random.seed(56) config: Mapping[str, Any] = {} diff --git a/tests/test_functional.py b/tests/test_functional.py index fcef8618..1cc4499d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -76,7 +76,6 @@ class DBFunctionalTestCase(DBFunctionalTestCaseBase): schema_name = "public" alt_orm_file_path = Path("my_orm.yaml") - alt_datafaker_file_path = Path("my_df.py") def test_workflow_minimal_args(self) -> None: """Test the recommended CLI workflow runs without errors.""" @@ -119,14 +118,10 @@ def test_workflow_minimal_args(self) -> None: "create-data", ) self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) self.assertSuccess(completed_process) self.assertEqual( - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" "Generating data for story 'story_generators.short_story'\n" "Generating data for story 'story_generators.full_row_story'\n" - "Generating data for story 'story_generators.long_story'\n" "Generating data for story 'story_generators.long_story'\n", completed_process.stdout, ) @@ -453,14 +448,6 @@ def test_unique_constraint_fail(self) -> 
None: f"--config-file={self.config_file_path}", "--force", ) - self.invoke( - "create-generators", - f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", - f"--config-file={self.config_file_path}", - f"--stats-file={self.stats_file_path}", - "--force", - ) self.invoke( "create-tables", f"--orm-file={self.alt_orm_file_path}", @@ -496,7 +483,6 @@ def test_unique_constraint_fail(self) -> None: "create-data", f"--config-file={self.config_file_path}", f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", "--num-passes=3", ) self.assertEqual("", completed_process.stderr) @@ -518,7 +504,6 @@ def test_unique_constraint_fail(self) -> None: "create-data", f"--config-file={self.config_file_path}", f"--orm-file={self.alt_orm_file_path}", - f"--df-file={self.alt_datafaker_file_path}", "--num-passes=1", expected_error=( "Failed to satisfy unique constraints for table unique_constraint_test" diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index fd26fd69..e896b98f 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -178,7 +178,6 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: gc.do_quit("") self.set_configuration(gc.config) self.get_src_stats(gc.config) - self.create_generators(gc.config) self.create_tables() self.populate_measurement_type_vocab() self.create_data(gc.config, num_passes=generate_count) @@ -293,7 +292,6 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No gc.do_quit("") self.set_configuration(gc.config) self.get_src_stats(gc.config) - self.create_generators(gc.config) self.create_tables() self.populate_measurement_type_vocab() self.create_data(gc.config, num_passes=generate_count) @@ -408,7 +406,6 @@ def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: gc.do_quit("") 
self.set_configuration(gc.config) self.get_src_stats(gc.config) - self.create_generators(gc.config) self.create_tables() self.populate_measurement_type_vocab() self.create_data(gc.config, num_passes=generate_count) @@ -442,7 +439,6 @@ def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None: gc.do_quit("") self.set_configuration(gc.config) self.get_src_stats(gc.config) - self.create_generators(gc.config) self.create_tables() self.populate_measurement_type_vocab() self.create_data(gc.config, num_passes=generate_count) diff --git a/tests/test_main.py b/tests/test_main.py index a3f0d5da..b05cac08 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -15,161 +15,9 @@ runner = CliRunner(mix_stderr=False) -class TestCliGeneratorOutput(DatafakerTestCase): - """Tests for the command-line interface.""" - - use_temporary_cwd = True - example_conf = "example_config.yaml" - copy_files = [example_conf, "orm.yaml"] - copy_from_directory = Path("examples") - - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.load_metadata_for_output") - @patch("datafaker.settings.get_settings") - @patch("datafaker.main.make_table_generators") - @patch("datafaker.main.generators_require_stats") - # pylint: disable=too-many-positional-arguments,too-many-arguments - def test_create_generators( - self, - mock_require_stats: MagicMock, - mock_make: MagicMock, - mock_settings: MagicMock, - mock_load_meta: MagicMock, - mock_config: MagicMock, - ) -> None: - """Test the create-generators sub-command.""" - mock_require_stats.return_value = False - mock_make.return_value = "some text" - mock_settings.return_value.src_postges_dsn = "" - - result = runner.invoke( - app, - [ - "create-generators", - "--config-file", - self.example_conf, - ], - catch_exceptions=False, - ) - self.assertSuccess(result) - - mock_make.assert_called_once_with( - mock_load_meta.return_value, - mock_config.return_value, - Path("orm.yaml"), - Path(self.example_conf), - None, - ) - with 
Path("df.py").open(encoding="utf-8") as dfh: - self.assertEqual(dfh.read(), "some text") - - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.load_metadata_for_output") - @patch("datafaker.settings.get_settings") - @patch("datafaker.main.make_table_generators") - @patch("datafaker.main.generators_require_stats") - # pylint: disable=too-many-positional-arguments,too-many-arguments - def test_create_generators_uses_default_stats_file_if_necessary( - self, - mock_require_stats: MagicMock, - mock_make: MagicMock, - mock_settings: MagicMock, - mock_load_meta: MagicMock, - mock_config: MagicMock, - ) -> None: - """Test the create-generators sub-command.""" - mock_require_stats.return_value = True - mock_make.return_value = "some text" - mock_settings.return_value.src_postges_dsn = "" - - result = runner.invoke( - app, - [ - "create-generators", - "--config-file", - self.example_conf, - ], - catch_exceptions=False, - ) - - mock_make.assert_called_once_with( - mock_load_meta.return_value, - mock_config.return_value, - Path("orm.yaml"), - Path(self.example_conf), - Path("src-stats.yaml"), - ) - self.assertSuccess(result) - with Path("df.py").open(encoding="utf-8") as dfh: - self.assertEqual(dfh.read(), "some text") - - @patch("datafaker.main.logger") - def test_create_generators_errors_if_file_exists( - self, - mock_logger: MagicMock, - ) -> None: - """Test the create-generators sub-command doesn't overwrite.""" - df_path = Path("df.py") - - with df_path.open(mode="w", encoding="utf-8") as dfh: - dfh.write("already exists!\n") - - result = runner.invoke( - app, - [ - "create-generators", - "--config-file", - self.example_conf, - ], - catch_exceptions=False, - ) - mock_logger.error.assert_called_once_with( - "%s should not already exist. 
Exiting...", - df_path, - ) - self.assertEqual(1, result.exit_code) - - class TestCLI(DatafakerTestCase): """Tests for the command-line interface.""" - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.load_metadata_for_output") - @patch("datafaker.settings.get_settings") - @patch("datafaker.main.make_table_generators") - # pylint: disable=too-many-positional-arguments,too-many-arguments - def test_create_generators_with_force_enabled( - self, - mock_make: MagicMock, - mock_settings: MagicMock, - mock_load_meta: MagicMock, - mock_config: MagicMock, - ) -> None: - """Tests the create-generators sub-commands overwrite files when instructed.""" - - mock_make.return_value = "make result" - mock_settings.return_value.src_postges_dsn = "" - - for force_option in ["--force", "-f"]: - with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke( - app, - [ - "create-generators", - force_option, - ], - ) - - self.assertSuccess(result) - mock_make.assert_called_once_with( - mock_load_meta.return_value, - mock_config.return_value, - Path("orm.yaml"), - Path("config.yaml"), - None, - ) - mock_make.reset_mock() - @patch("datafaker.main.create_db_tables") @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") @@ -193,9 +41,11 @@ def test_create_tables( @patch("datafaker.main.logger") @patch("datafaker.main.create_db_data") @patch("datafaker.main.load_metadata_for_output") + @patch("datafaker.main.read_config_file") # pylint: disable=too-many-arguments too-many-positional-arguments def test_create_data( self, + mock_read_config: MagicMock, mock_load_metadata: MagicMock, mock_create: MagicMock, mock_logger: MagicMock, @@ -216,7 +66,8 @@ def test_create_data( mock_create.assert_called_once_with( mock_tables.return_value, - mock_import.return_value, + mock_read_config.return_value, + None, 1, mock_load_metadata.return_value, ) diff --git a/tests/utils.py b/tests/utils.py index 544011f1..21176236 100644 --- 
a/tests/utils.py +++ b/tests/utils.py @@ -26,7 +26,7 @@ from datafaker import settings from datafaker.create import create_db_data_into, create_db_tables_into from datafaker.interactive.base import DbCmd -from datafaker.make import make_src_stats, make_table_generators, make_tables_file +from datafaker.make import make_src_stats, make_tables_file from datafaker.utils import ( MaybeAsyncEngine, T, @@ -466,19 +466,6 @@ def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: stats_fh.write(yaml.dump(src_stats)) return src_stats - def create_generators(self, config: Mapping[str, Any]) -> None: - """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" - datafaker_content = make_table_generators( - self.metadata, - config, - Path(self.orm_file_path), - Path(self.config_file_path), - Path(self.stats_file_path), - ) - (generators_fd, self.generators_file_path) = mkstemp(".py", "dfgen_", text=True) - with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh: - datafaker_fh.write(datafaker_content) - def create_tables(self) -> None: """Create tables in the output DB.""" create_db_tables_into(self.metadata, self.dst_dsn, self.dst_schema_name) @@ -505,7 +492,6 @@ def generate_data( """ self.set_configuration(config) src_stats = self.get_src_stats(config) - self.create_generators(config) self.create_tables() self.create_data(config, num_passes) return src_stats From edda8354b947118710b43bc35d2bebdcfaab95ff Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 23 Mar 2026 19:27:19 +0000 Subject: [PATCH 36/44] All tests pass. 
--- datafaker/create.py | 22 ++++++----- datafaker/make.py | 2 +- datafaker/populate.py | 49 +++++++++++++---------- datafaker/utils.py | 4 +- tests/test_create.py | 72 ++++++++++++++++++++++++++++------ tests/test_functional.py | 84 ++++++++++++++-------------------------- tests/utils.py | 9 ++++- 7 files changed, 143 insertions(+), 99 deletions(-) diff --git a/datafaker/create.py b/datafaker/create.py index d32c15db..785ebecd 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -9,16 +9,15 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table +import typer from datafaker.base import FileUploader -from datafaker.make import get_generation_info +from datafaker.make import get_generation_info, StoryGeneratorInfo from datafaker.populate import ( - StoryGeneratorInfo, TableGenerator, call_function, get_symbols, get_table_generator_dict, - get_vocab_dict, ) from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings from datafaker.utils import ( @@ -157,8 +156,15 @@ def create_db_data( ) -> RowCounts: """Connect to a database and populate it with data.""" if src_stats_filename: - with src_stats_filename.open(encoding="utf-8") as fh: - src_stats = yaml.load(fh, yaml.SafeLoader) + try: + with src_stats_filename.open(encoding="utf-8") as fh: + src_stats = yaml.load(fh, yaml.SafeLoader) + except FileNotFoundError: + logger.error( + "No source stats file '%', this should be the output of the 'make-stats' command", + src_stats_filename, + ) + raise typer.Exit(1) else: src_stats = None return create_db_data_into( @@ -187,10 +193,8 @@ def create_db_data_into( :param sorted_tables: The table names to populate, sorted so that foreign keys' targets are populated before the foreign keys themselves. - :param table_generator_dict: A mapping of table names to the generators - used to make data for them. 
- :param story_generator_list: A list of story generators to be run after the - table generators on each pass. + :param config: The data from the ``config.yaml`` file. + :param src_stats: The data from the ``src-stats.yaml`` file. :param num_passes: Number of passes to perform. :param db_dsn: Connection string for the destination database. :param schema_name: Destination schema name. diff --git a/datafaker/make.py b/datafaker/make.py index 6f9d4c16..d04d0476 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -656,7 +656,7 @@ def get_generation_info( story_generators = _get_story_generators(config) max_unique_constraint_tries = get_property( - config, "max-unique-constraint-tries", str | None, None + config, "max-unique-constraint-tries", int | None, None ) return GenerationInfo( provider_imports=PROVIDER_IMPORTS, diff --git a/datafaker/populate.py b/datafaker/populate.py index 42861a42..1a5e006a 100644 --- a/datafaker/populate.py +++ b/datafaker/populate.py @@ -6,7 +6,7 @@ from typing import Any, Callable from datafaker.base import FileUploader, ColumnPresence -from datafaker.make import FunctionCall, TableGeneratorInfo, StoryGeneratorInfo +from datafaker.make import FunctionCall, TableGeneratorInfo from datafaker.providers import ( BytesProvider, @@ -18,17 +18,34 @@ TimespanProvider, WeightedBooleanProvider, ) -from datafaker.utils import logging, get_vocabulary_table_names, import_file +from datafaker.utils import get_vocabulary_table_names, import_file + +def make_generic(): + g = Generic(locale=Locale.EN_GB) + g.add_providers( + BytesProvider, + ColumnValueProvider, + DistributionProvider, + NullProvider, + SQLGroupByProvider, + TimedeltaProvider, + TimespanProvider, + WeightedBooleanProvider, + ) + return g + + +generic = make_generic() -generic = Generic(locale=Locale.EN_GB) -generic.add_provider(BytesProvider) -generic.add_provider(ColumnValueProvider) -generic.add_provider(NullProvider) -generic.add_provider(SQLGroupByProvider) 
-generic.add_provider(TimedeltaProvider) -generic.add_provider(TimespanProvider) -generic.add_provider(WeightedBooleanProvider) +def reset_generic(): + """ + Reset all the generators. + + Only really useful in test code. + """ + global generic + generic = make_generic() def _eval_structure(config: Any, context: Mapping) -> Any: @@ -57,10 +74,10 @@ def _eval_structure(config: Any, context: Mapping) -> Any: def _get_object(class_name: str, context: Mapping) -> Any: """ - Get an object out of the context. + Fetch an object from the context. :param class_name: The name of the class, qualified if necessary. - Like "module.MyClass.Nested" + Like "module.MyClass.Nested" :param context: Mapping of strings to objects with those names. :return: A value from ``context`` if there are no qualifying names, otherwise the attribute of the base object. @@ -288,11 +305,3 @@ def get_table_generator_dict( ) for table_data in tables_data } - - -def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[str, FileUploader]: - """Get a dict of table names to objects that can populate those tables from YAML files.""" - return { - name: FileUploader(metadata.tables[name]) - for name in get_vocabulary_table_names(config) - } diff --git a/datafaker/utils.py b/datafaker/utils.py index 25504a17..d33cf751 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -546,9 +546,7 @@ def get_row_generators( :param table_config: The element from the ``tables:`` stanza of ``config.xml``. :return: Pair of (name, row generator config). 
""" - rgs = table_config.get("row_generators", None) - if isinstance(rgs, str) or not hasattr(rgs, "__iter__"): - return + rgs = get_property(table_config, "row_generators", list, []) for rg in rgs: name = rg.get("name", None) if name: diff --git a/tests/test_create.py b/tests/test_create.py index ff1df950..f18bc994 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -21,6 +21,7 @@ create_db_vocab, populate, ) +from datafaker.make import FunctionCall, StoryGeneratorInfo from datafaker.serialize_metadata import metadata_to_dict, dict_to_metadata from datafaker.utils import sorted_non_vocabulary_tables from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase @@ -99,7 +100,8 @@ class TestPopulate(DatafakerTestCase): """Test create.populate.""" # pylint: disable=too-many-locals - def test_populate(self) -> None: + @patch("datafaker.populate._get_object") + def test_populate(self, mock_get_object: MagicMock) -> None: """Test the populate function.""" table_name = "table_name" @@ -107,9 +109,7 @@ def story() -> Generator[Tuple[str, dict], None, None]: """Mock story.""" yield table_name, {} - def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: - """A function that returns mock stories.""" - return story() + mock_get_object.return_value = story for num_stories_per_pass, num_rows_per_pass, num_initial_rows in itt.product( [0, 2], [0, 3], [0, 17] @@ -130,11 +130,11 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: story_generators: list[dict[str, Any]] = ( [ - { - "function": mock_story_gen, - "num_stories_per_pass": num_stories_per_pass, - "name": "mock_story_gen", - } + StoryGeneratorInfo( + "mock_story_gen name", + FunctionCall("mock_story_gen", [], {}), + num_stories_per_pass, + ) ] if num_stories_per_pass > 0 else [] @@ -304,11 +304,11 @@ def __call__( create_db_data_into( [MagicMock()], MagicMock(), + None, 1, "duckdb:///:memory:data", None, MagicMock(), - MagicMock(), ) assert 
mock_populate.side_effect.called @@ -366,11 +366,12 @@ def test_create_data_minimal(self) -> None: } metadata = dict_to_metadata(orm, config) create_db_tables_into(metadata, self.dsn, self.schema_name) + generate_count = 4 row_counts = create_db_data_into( sorted_non_vocabulary_tables(metadata, config), config, None, - 4, + generate_count, self.dsn, self.schema_name, metadata, @@ -379,3 +380,52 @@ def test_create_data_minimal(self) -> None: stmt = select(metadata.tables["one"]) rows = connection.execute(stmt).fetchall() self.assertListEqual(rows, [(1,), (2,), (3,), (4,)]) + self.assertListEqual(list(row_counts.keys()), ['one']) + self.assertEqual(row_counts["one"], generate_count) + + def test_unique_constraint_minimal(self) -> None: + config = { + "tables": { + "one": { + "row_generators": [{ + "name": "dist_gen.constant", + "kwargs": { + "value": 123, + }, + "columns_assigned": ["tiger"], + }] + } + }, + "max-unique-constraint-tries": 20, + } + orm = { + "tables": { + "one": { + "columns": { + "id": { + "primary": True, + "type": "INTEGER", + }, + "tiger": { + "type": "INTEGER", + }, + }, + "unique": [ + {"name": "tiger_uniq", "columns": ["tiger"]} + ] + } + } + } + metadata = dict_to_metadata(orm, config) + create_db_tables_into(metadata, self.dsn, self.schema_name) + self.assertRaises( + RuntimeError, + create_db_data_into, + sorted_non_vocabulary_tables(metadata, config), + config, + None, + 2, + self.dsn, + self.schema_name, + metadata, + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 1cc4499d..a77f9c96 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -272,52 +272,34 @@ def test_workflow_maximal_args(self) -> None: f"--config-file={self.config_file_path}", "--num-passes=2", ) - self.assertEqual("", completed_process.stderr) - self.assertEqual( - sorted( - [ - "Creating data.", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - 
"Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for table 'data_type_test'", - "Generating data for table 'data_type_test'", - "Generating data for table 'no_pk_test'", - "Generating data for table 'no_pk_test'", - "Generating data for table 'person'", - "Generating data for table 'person'", - "Generating data for table 'strange_type_table'", - "Generating data for table 'strange_type_table'", - "Generating data for table 'unique_constraint_test'", - "Generating data for table 'unique_constraint_test'", - "Generating data for table 'unique_constraint_test2'", - "Generating data for table 'unique_constraint_test2'", - "Generating data for table 'test_entity'", - "Generating data for table 'test_entity'", - "Generating data for table 'hospital_visit'", - "Generating data for table 'hospital_visit'", - "Data created in 2 passes.", - f"person: {2*(3+1+2+2)} rows created.", - f"hospital_visit: {2*(2*2+3)} rows created.", - "data_type_test: 2 rows created.", - "no_pk_test: 2 rows created.", - "strange_type_table: 2 rows created.", - "unique_constraint_test: 2 rows created.", - "unique_constraint_test2: 2 rows created.", - "test_entity: 2 rows created.", - "", - ] - ), - sorted(completed_process.stdout.split("\n")), + self.assertSetEqual( + { + "Creating data.", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.full_row_story'", + 
"Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for table 'data_type_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'person'", + "Generating data for table 'person'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'test_entity'", + "Generating data for table 'hospital_visit'", + "Data created in 2 passes.", + f"person: {2*(3+1+2+2)} rows created.", + f"hospital_visit: {2*(2*2+3)} rows created.", + "data_type_test: 2 rows created.", + "no_pk_test: 2 rows created.", + "strange_type_table: 2 rows created.", + "unique_constraint_test: 2 rows created.", + "unique_constraint_test2: 2 rows created.", + "test_entity: 2 rows created.", + "", + }, + set(completed_process.stdout.split("\n")), ) completed_process = self.invoke( @@ -468,13 +450,9 @@ def test_unique_constraint_fail(self) -> None: f"--stats-file={self.stats_file_path}", "--num-passes=1", ) - self.assertEqual("", completed_process.stderr) self.assertEqual( - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" "Generating data for story 'story_generators.short_story'\n" "Generating data for story 'story_generators.full_row_story'\n" - "Generating data for story 'story_generators.long_story'\n" "Generating data for story 'story_generators.long_story'\n", completed_process.stdout, ) @@ -483,17 +461,14 @@ def test_unique_constraint_fail(self) -> None: "create-data", f"--config-file={self.config_file_path}", f"--orm-file={self.alt_orm_file_path}", + f"--stats-file={self.stats_file_path}", "--num-passes=3", ) - self.assertEqual("", completed_process.stderr) self.assertEqual( ( - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 
'story_generators.short_story'\n" "Generating data for story 'story_generators.short_story'\n" "Generating data for story 'story_generators.full_row_story'\n" "Generating data for story 'story_generators.long_story'\n" - "Generating data for story 'story_generators.long_story'\n" ) * 3, completed_process.stdout, @@ -504,6 +479,7 @@ def test_unique_constraint_fail(self) -> None: "create-data", f"--config-file={self.config_file_path}", f"--orm-file={self.alt_orm_file_path}", + f"--stats-file={self.stats_file_path}", "--num-passes=1", expected_error=( "Failed to satisfy unique constraints for table unique_constraint_test" diff --git a/tests/utils.py b/tests/utils.py index 21176236..b07aadae 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,6 +27,7 @@ from datafaker.create import create_db_data_into, create_db_tables_into from datafaker.interactive.base import DbCmd from datafaker.make import make_src_stats, make_tables_file +from datafaker.populate import reset_generic from datafaker.utils import ( MaybeAsyncEngine, T, @@ -244,6 +245,7 @@ def setUp(self) -> None: """Set up the test case with an actual orm.yaml file.""" super().setUp() settings.get_settings.cache_clear() + reset_generic() if self.use_temporary_cwd: self.start_dir = os.getcwd() self.working_dir = mkdtemp("test") @@ -473,10 +475,15 @@ def create_tables(self) -> None: def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff + if self.stats_file_path is None: + src_stats = None + else: + with Path(self.stats_file_path).open(encoding="utf-8") as fh: + src_stats = yaml.load(fh, yaml.SafeLoader) create_db_data_into( sorted_non_vocabulary_tables(self.metadata, config), config, - Path(self.stats_file_path), + src_stats, num_passes, self.dst_dsn, self.dst_schema_name, From c1c285870f6c864398a2814ab5a2f43849b87dce Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 23 Mar 2026 19:55:44 +0000 Subject: [PATCH 37/44] A 
few pre-commit fixes --- datafaker/base.py | 8 +++--- datafaker/create.py | 17 +++++++------ datafaker/dump.py | 3 +-- datafaker/main.py | 6 +---- datafaker/make.py | 9 +++---- datafaker/populate.py | 57 +++++++++++++++++-------------------------- tests/test_create.py | 31 +++++++++++------------ 7 files changed, 58 insertions(+), 73 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index ae2ad7b5..0c306dec 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -1,15 +1,15 @@ """Base table generator classes.""" +import gzip +import os +import random from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass -import gzip from io import TextIOWrapper -import os from pathlib import Path -import random from typing import Any -import yaml +import yaml from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.schema import MetaData, Table diff --git a/datafaker/create.py b/datafaker/create.py index 785ebecd..eb032119 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -2,17 +2,17 @@ from collections import Counter from pathlib import Path from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple -import yaml +import typer +import yaml from sqlalchemy import Connection, insert, inspect from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table -import typer from datafaker.base import FileUploader -from datafaker.make import get_generation_info, StoryGeneratorInfo +from datafaker.make import StoryGeneratorInfo, get_generation_info from datafaker.populate import ( TableGenerator, call_function, @@ -255,7 +255,7 @@ def __init__( self._story = iter([]) self.next() - def _get_next_story(self) -> None: + def _get_next_story(self) -> bool: """ Iterate to the next ``_story_infos``. 
@@ -265,7 +265,9 @@ def _get_next_story(self) -> None: sgi = next(self._story_infos) self._story_counts = sgi.num_stories_per_pass self._story_function_call = sgi.function_call - logger.info("Generating data for story '%s'", sgi.function_call.function_name) + logger.info( + "Generating data for story '%s'", sgi.function_call.function_name + ) self._story = call_function(sgi.function_call, self._context) self._final_values = None except StopIteration: @@ -289,7 +291,6 @@ def is_ended(self) -> bool: """ return self._story_counts == -1 - def has_table(self, table_name: str) -> bool: """Check if we have a row for table ``table_name``.""" return table_name == self._table_name @@ -346,7 +347,9 @@ def next(self) -> None: self._story_counts -= 1 if 0 < self._story_counts: # Reinitialize the same story again - self._story = call_function(self._story_function_call, self._context) + self._story = call_function( + self._story_function_call, self._context + ) elif not self._get_next_story(): self._story_counts = -1 return diff --git a/datafaker/dump.py b/datafaker/dump.py index c115b0a1..21aa7d63 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,10 +1,9 @@ """Data dumping functions.""" import csv import io -from typing import TYPE_CHECKING - from abc import ABC, abstractmethod from pathlib import Path +from typing import TYPE_CHECKING import pandas as pd import sqlalchemy diff --git a/datafaker/main.py b/datafaker/main.py index 03b61656..daffb56a 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -29,11 +29,7 @@ update_missingness, ) from datafaker.interactive.base import DbCmd -from datafaker.make import ( - make_src_stats, - make_tables_file, - make_vocabulary_tables, -) +from datafaker.make import make_src_stats, make_tables_file, make_vocabulary_tables from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.settings import ( SettingsError, diff --git a/datafaker/make.py b/datafaker/make.py index d04d0476..dbbc6f84 
100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -66,8 +66,8 @@ class FunctionCall: """Which function to call with what.""" function_name: str - args: list[str] - kwargs: dict[str, str] + args: list[Any] + kwargs: dict[str, Any] @dataclass @@ -585,6 +585,7 @@ def make_vocabulary_tables( @dataclass class GenerationInfo: """Information for the generation of all data.""" + provider_imports: list[str] row_generator_module_name: str | None story_generator_module_name: str | None @@ -623,9 +624,7 @@ def get_generation_info( story_generator_module_name = get_property( config, "story_generators_module", str | None, None ) - object_instantiation = get_property( - config, "object_instantiation", dict, {} - ) + object_instantiation = get_property(config, "object_instantiation", dict, {}) tables_config = get_property(config, "tables", dict, {}) tables: list[TableGeneratorInfo] = [] diff --git a/datafaker/populate.py b/datafaker/populate.py index 1a5e006a..3d13c5ed 100644 --- a/datafaker/populate.py +++ b/datafaker/populate.py @@ -1,13 +1,13 @@ from collections.abc import Iterable, Mapping, MutableMapping, Sequence from pathlib import Path +from typing import Any, Callable + +import sqlalchemy from mimesis import Generic from mimesis.locales import Locale -import sqlalchemy -from typing import Any, Callable -from datafaker.base import FileUploader, ColumnPresence +from datafaker.base import ColumnPresence, FileUploader from datafaker.make import FunctionCall, TableGeneratorInfo - from datafaker.providers import ( BytesProvider, ColumnValueProvider, @@ -20,7 +20,8 @@ ) from datafaker.utils import get_vocabulary_table_names, import_file -def make_generic(): + +def make_generic() -> Generic: g = Generic(locale=Locale.EN_GB) g.add_providers( BytesProvider, @@ -38,7 +39,7 @@ def make_generic(): generic = make_generic() -def reset_generic(): +def reset_generic() -> None: """ Reset all the generators. 
@@ -63,10 +64,7 @@ def _eval_structure(config: Any, context: Mapping) -> Any: except NameError as exc: raise exc if isinstance(config, Mapping): - return { - k: _eval_structure(v, context) - for k, v in config.items() - } + return {k: _eval_structure(v, context) for k, v in config.items()} if isinstance(config, Sequence): return [_eval_structure(v, context) for v in config] return config @@ -96,10 +94,7 @@ def _get_object(class_name: str, context: Mapping) -> Any: def _call_from_context( - callable_name: str, - args: list[Any], - kwargs: dict[str, Any], - context: Mapping + callable_name: str, args: list[Any], kwargs: dict[str, Any], context: Mapping ) -> Any: """ Call a callable from the classes (or functions) in the context. @@ -111,14 +106,8 @@ def _call_from_context( cls = _get_object(callable_name, context) if not isinstance(cls, Callable): return None - arg_objs = [ - _eval_structure(arg, context) - for arg in args - ] - kwarg_objs = { - k: _eval_structure(v, context) - for k, v in kwargs.items() - } + arg_objs = [_eval_structure(arg, context) for arg in args] + kwarg_objs = {k: _eval_structure(v, context) for k, v in kwargs.items()} return cls(*arg_objs, **kwarg_objs) @@ -191,7 +180,6 @@ def _get_symbols_instantiation(symbols: dict[str, Any], objs: dict[str, Any]) -> class TableGenerator: - def __init__( self, dst_db_conn: sqlalchemy.Connection, @@ -217,10 +205,9 @@ def __init__( for constraint in table_data.unique_constraints: expr = sqlalchemy.select(constraint.columns) query_result = dst_db_conn.execute(expr).fetchall() - self.existing_constraint_hashes[constraint.name] = set([ - hash(tuple(result)) - for result in query_result - ]) + self.existing_constraint_hashes[constraint.name] = set( + [hash(tuple(result)) for result in query_result] + ) @property def num_rows_per_pass(self): @@ -242,13 +229,17 @@ def __call__(self, db_conn: sqlalchemy.Connection): columns_to_generate = set(self.table_data.nonnull_columns) # Which missingness patterns do we want? 
for choice in self.table_data.column_choices: - cols = _call_from_context(choice.function_name, choice.args, choice.kwargs, self.context) + cols = _call_from_context( + choice.function_name, choice.args, choice.kwargs, self.context + ) columns_to_generate.update(cols) max_tries = self.max_unique_constraint_tries while columns_to_generate: if max_tries == 0: - raise RuntimeError(f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts.") + raise RuntimeError( + f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts." + ) if max_tries is not None: max_tries -= 1 for row_gen in self.table_data.row_gens: @@ -264,15 +255,11 @@ def __call__(self, db_conn: sqlalchemy.Connection): result[variable_name] = values[index] columns_to_generate = set() for constraint in self.table_data.unique_constraints: - cf_hash = hash(tuple( - result[col.name] for col in constraint.columns - )) + cf_hash = hash(tuple(result[col.name] for col in constraint.columns)) if cf_hash in self.existing_constraint_hashes[constraint.name]: columns_to_generate.update(c.name for c in constraint.columns) for constraint in self.table_data.unique_constraints: - cf_hash = hash(tuple( - result[col.name] for col in constraint.columns - )) + cf_hash = hash(tuple(result[col.name] for col in constraint.columns)) self.existing_constraint_hashes[constraint.name].add(cf_hash) return result diff --git a/tests/test_create.py b/tests/test_create.py index f18bc994..48b0e695 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -13,7 +13,6 @@ from sqlalchemy import Connection, Engine, select from sqlalchemy.schema import MetaData, Table -from datafaker.populate import TableGenerator from datafaker.create import ( create_db_data_into, create_db_tables, @@ -22,7 +21,8 @@ populate, ) from datafaker.make import FunctionCall, StoryGeneratorInfo -from 
datafaker.serialize_metadata import metadata_to_dict, dict_to_metadata +from datafaker.populate import TableGenerator +from datafaker.serialize_metadata import dict_to_metadata, metadata_to_dict from datafaker.utils import sorted_non_vocabulary_tables from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase @@ -128,7 +128,7 @@ def story() -> Generator[Tuple[str, dict], None, None]: {table_name: num_initial_rows} if num_initial_rows > 0 else {} ) - story_generators: list[dict[str, Any]] = ( + story_generators: list[StoryGeneratorInfo] = ( [ StoryGeneratorInfo( "mock_story_gen name", @@ -345,6 +345,7 @@ def __call__(self, connection: Connection, base_path: Path) -> None: class CreateDataTestCase(RequiresDBTestCase): """Tests for create-data.""" + dump_file_path = "empty.sql" database_name = "empty" schema_name = "public" @@ -379,21 +380,23 @@ def test_create_data_minimal(self) -> None: with self.sync_engine.connect() as connection: stmt = select(metadata.tables["one"]) rows = connection.execute(stmt).fetchall() - self.assertListEqual(rows, [(1,), (2,), (3,), (4,)]) - self.assertListEqual(list(row_counts.keys()), ['one']) + self.assertEqual(rows, [(1,), (2,), (3,), (4,)]) + self.assertListEqual(list(row_counts.keys()), ["one"]) self.assertEqual(row_counts["one"], generate_count) def test_unique_constraint_minimal(self) -> None: config = { "tables": { "one": { - "row_generators": [{ - "name": "dist_gen.constant", - "kwargs": { - "value": 123, - }, - "columns_assigned": ["tiger"], - }] + "row_generators": [ + { + "name": "dist_gen.constant", + "kwargs": { + "value": 123, + }, + "columns_assigned": ["tiger"], + } + ] } }, "max-unique-constraint-tries": 20, @@ -410,9 +413,7 @@ def test_unique_constraint_minimal(self) -> None: "type": "INTEGER", }, }, - "unique": [ - {"name": "tiger_uniq", "columns": ["tiger"]} - ] + "unique": [{"name": "tiger_uniq", "columns": ["tiger"]}], } } } From f914804b36feb7cbfdbf297e92c4e36adb4a0b1b Mon Sep 17 00:00:00 
2001 From: Tim Band Date: Tue, 24 Mar 2026 16:48:02 +0000 Subject: [PATCH 38/44] Cleaned pre-commit checks --- datafaker/base.py | 3 +- datafaker/create.py | 26 ++++++---- datafaker/dump.py | 1 - datafaker/generators/partitioned.py | 2 +- datafaker/main.py | 11 ++--- datafaker/make.py | 74 ++++++++++++----------------- datafaker/populate.py | 66 +++++++++++++------------ datafaker/serialize_metadata.py | 13 +++-- datafaker/utils.py | 39 ++++++++++++--- tests/test_create.py | 3 +- tests/test_functional.py | 2 - tests/test_rst.py | 14 ------ tests/utils.py | 3 -- 13 files changed, 130 insertions(+), 127 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 0c306dec..fdd0339e 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -2,7 +2,6 @@ import gzip import os import random -from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass from io import TextIOWrapper @@ -12,7 +11,7 @@ import yaml from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.schema import MetaData, Table +from sqlalchemy.schema import Table from datafaker.utils import ( MAKE_VOCAB_PROGRESS_REPORT_EVERY, diff --git a/datafaker/create.py b/datafaker/create.py index eb032119..f91dbb0c 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -12,7 +12,7 @@ from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table from datafaker.base import FileUploader -from datafaker.make import StoryGeneratorInfo, get_generation_info +from datafaker.make import FunctionCall, StoryGeneratorInfo, get_generation_info from datafaker.populate import ( TableGenerator, call_function, @@ -159,12 +159,12 @@ def create_db_data( try: with src_stats_filename.open(encoding="utf-8") as fh: src_stats = yaml.load(fh, yaml.SafeLoader) - except FileNotFoundError: + except FileNotFoundError as exc: logger.error( - "No source stats file '%', this should be the output of the 'make-stats' 
command", + "No source stats file '%s', this should be the output of the 'make-stats' command", src_stats_filename, ) - raise typer.Exit(1) + raise typer.Exit(1) from exc else: src_stats = None return create_db_data_into( @@ -205,7 +205,7 @@ def create_db_data_into( context = get_symbols( gen_info.row_generator_module_name, gen_info.story_generator_module_name, - get_property(config, "object_instantiation", dict, {}), + get_property(config, "object_instantiation", {}), src_stats, metadata, ) @@ -229,6 +229,14 @@ def create_db_data_into( return row_counts +def empty_story_generator() -> ( + Generator[tuple[str, dict[str, Any]], dict[str, Any], None] +): + """Get a story generator that generates no values.""" + empt: list[tuple[str, dict[str, Any]]] = [] + yield from empt + + # pylint: disable=too-many-instance-attributes class StoryIterator: """Iterates through all the rows produced by all the stories.""" @@ -250,9 +258,10 @@ def __init__( self._final_values: dict[str, Any] | None = None # Number of times the current story should be run self._story_counts = 1 - self._story_function_call = None + self._story_function_call: FunctionCall self._context = context - self._story = iter([]) + self._story = empty_story_generator() + self._provided_values: dict[str, Any] self.next() def _get_next_story(self) -> bool: @@ -276,6 +285,7 @@ def _get_next_story(self) -> bool: return True def _get_values(self) -> None: + """Get the values from the current story and advance the iterator.""" if self._final_values is None: self._table_name, self._provided_values = next(self._story) else: @@ -342,7 +352,7 @@ def next(self) -> None: try: self._get_values() return - except StopIteration as exc: + except StopIteration: self._final_values = None self._story_counts -= 1 if 0 < self._story_counts: diff --git a/datafaker/dump.py b/datafaker/dump.py index 21aa7d63..c4b1280e 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -3,7 +3,6 @@ import io from abc import ABC, abstractmethod from 
pathlib import Path -from typing import TYPE_CHECKING import pandas as pd import sqlalchemy diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py index 493c2641..a7b9327f 100644 --- a/datafaker/generators/partitioned.py +++ b/datafaker/generators/partitioned.py @@ -445,7 +445,7 @@ def get_named_tables(self) -> Mapping[str, str]: def __init__(self, config: Mapping[str, Any]) -> None: """Initialize the null partitioned generator factory.""" - tables = get_property(config, "tables", dict, {}) + tables: dict[str, Any] = get_property(config, "tables", {}) self._named_tables = { table_name: table_conf["name_column"] for table_name, table_conf in tables.items() diff --git a/datafaker/main.py b/datafaker/main.py index daffb56a..eb6cc181 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -44,7 +44,6 @@ generated_tables, generators_require_stats, get_flag, - import_file, logger, read_config_file, sorted_non_vocabulary_tables, @@ -263,22 +262,22 @@ def create_tables( @app.command() def create_generators( - orm_file: Path = Option( + _orm_file: Path = Option( ORM_FILENAME, help="The name of the ORM yaml file", dir_okay=False, ), - df_file: Path = Option( + _df_file: Path = Option( DF_FILENAME, help="Path to write Python generators to.", dir_okay=False, ), - config_file: Path = Option( + _config_file: Path = Option( CONFIG_FILENAME, help="The configuration file", dir_okay=False, ), - stats_file: Optional[Path] = Option( + _stats_file: Optional[Path] = Option( None, help=( "Statistics file (output of make-stats); default is src-stats.yaml if the " @@ -287,7 +286,7 @@ def create_generators( show_default=False, dir_okay=False, ), - force: bool = Option( + _force: bool = Option( False, "--force", "-f", help="Overwrite any existing Python generators file." 
), ) -> None: diff --git a/datafaker/make.py b/datafaker/make.py index dbbc6f84..2ab793b9 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -7,7 +7,7 @@ from datetime import datetime from pathlib import Path from types import TracebackType -from typing import Any, Final, Optional, Tuple, Type, Union +from typing import Any, Final, Optional, Tuple, Type import pandas as pd import snsql @@ -15,11 +15,17 @@ from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider -from sqlalchemy import CursorResult, Engine, MetaData, UniqueConstraint, text +from sqlalchemy import CursorResult, Engine, MetaData, text from sqlalchemy.dialects import postgresql from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine -from sqlalchemy.schema import Column, Table +from sqlalchemy.schema import ( + Column, + ColumnCollectionConstraint, + PrimaryKeyConstraint, + Table, + UniqueConstraint, +) from sqlalchemy.sql import Executable, sqltypes from typing_extensions import Self @@ -28,10 +34,12 @@ from datafaker.settings import get_source_dsn, get_source_schema from datafaker.utils import ( MaybeAsyncEngine, + constraint_name, create_db_engine, download_table, get_columns_assigned, get_property, + get_property_or_none, get_related_table_names, get_row_generators, get_sync_engine, @@ -66,8 +74,8 @@ class FunctionCall: """Which function to call with what.""" function_name: str - args: list[Any] - kwargs: dict[str, Any] + args: Sequence[Any] + kwargs: Mapping[str, Any] @dataclass @@ -110,18 +118,6 @@ def make_column_choices( ] -@dataclass -class _PrimaryConstraint: - """ - Describes a Uniqueness constraint for a multi-column primary key. - - Not a real constraint, but enough to write df.py. 
- """ - - columns: list[Column] - name: str - - @dataclass class TableGeneratorInfo: """Contains the df.py content related to regular tables.""" @@ -132,7 +128,7 @@ class TableGeneratorInfo: column_choices: list[ColumnChoice] rows_per_pass: int row_gens: list[RowGeneratorInfo] = field(default_factory=list) - unique_constraints: Sequence[Union[UniqueConstraint, _PrimaryConstraint]] = field( + unique_constraints: Sequence[ColumnCollectionConstraint] = field( default_factory=list ) @@ -467,19 +463,6 @@ def _get_provider_for_column(column: Column) -> Tuple[list[str], str, dict[str, return variable_names, generator_function, generator_arguments -def _constraint_sort_key(constraint: UniqueConstraint) -> str: - """Extract a string out of a UniqueConstraint that is unique to that constraint. - - We sort the constraints so that the output of make_tables is deterministic, this is - the sort key. - """ - return ( - constraint.name - if isinstance(constraint.name, str) - else "_".join(map(str, constraint.columns)) - ) - - def _get_generator_for_table( table_config: Mapping[str, Any], table: Table, @@ -491,12 +474,12 @@ def _get_generator_for_table( for constraint in table.constraints if isinstance(constraint, UniqueConstraint) ), - key=_constraint_sort_key, + key=constraint_name, ) primary_keys = [c for c in table.columns if c.primary_key] - constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = unique_constraints + constraints: Sequence[ColumnCollectionConstraint] = unique_constraints if 1 < len(primary_keys): - primary_constraint = _PrimaryConstraint( + primary_constraint = PrimaryKeyConstraint( columns=primary_keys, name=make_primary_key_name(table.name) ) constraints = unique_constraints + [primary_constraint] @@ -514,7 +497,7 @@ def _get_generator_for_table( class_name=table.name.title().replace(".", "") + "Generator", nonnull_columns=nonnull_columns, column_choices=column_choices, - rows_per_pass=get_property(table_config, "num_rows_per_pass", int, 1), + 
rows_per_pass=get_property(table_config, "num_rows_per_pass", 1), unique_constraints=constraints, ) @@ -583,6 +566,7 @@ def make_vocabulary_tables( @dataclass +# pylint: disable=too-many-instance-attributes class GenerationInfo: """Information for the generation of all data.""" @@ -618,14 +602,16 @@ def get_generation_info( :return: A string that is a valid Python module, once written to file. """ - row_generator_module_name = get_property( - config, "row_generators_module", str | None, None + row_generator_module_name = get_property_or_none( + config, "row_generators_module", str + ) + story_generator_module_name = get_property_or_none( + config, "story_generators_module", str ) - story_generator_module_name = get_property( - config, "story_generators_module", str | None, None + object_instantiation: dict[str, Any] = get_property( + config, "object_instantiation", {} ) - object_instantiation = get_property(config, "object_instantiation", dict, {}) - tables_config = get_property(config, "tables", dict, {}) + tables_config: dict[str, Any] = get_property(config, "tables", {}) tables: list[TableGeneratorInfo] = [] vocabulary_tables: list[VocabularyTableGeneratorInfo] = [] @@ -654,8 +640,8 @@ def get_generation_info( story_generators = _get_story_generators(config) - max_unique_constraint_tries = get_property( - config, "max-unique-constraint-tries", int | None, None + max_unique_constraint_tries = get_property_or_none( + config, "max-unique-constraint-tries", int ) return GenerationInfo( provider_imports=PROVIDER_IMPORTS, @@ -724,7 +710,7 @@ def make_tables_file( if parquet_dir is not None: extra_meta = get_parquet_orm(parquet_dir) if extra_meta: - md_tables = get_property(meta_dict, "tables", dict, {}) + md_tables: dict[str, Any] = get_property(meta_dict, "tables", {}) new_tables = {**extra_meta, **md_tables} meta_dict["tables"] = new_tables diff --git a/datafaker/populate.py b/datafaker/populate.py index 3d13c5ed..d8879ce1 100644 --- a/datafaker/populate.py +++ 
b/datafaker/populate.py @@ -1,12 +1,12 @@ +"""Put the generated values into the database, obeying other restrictions.""" from collections.abc import Iterable, Mapping, MutableMapping, Sequence -from pathlib import Path -from typing import Any, Callable +from typing import Any import sqlalchemy from mimesis import Generic from mimesis.locales import Locale -from datafaker.base import ColumnPresence, FileUploader +from datafaker.base import ColumnPresence from datafaker.make import FunctionCall, TableGeneratorInfo from datafaker.providers import ( BytesProvider, @@ -18,10 +18,11 @@ TimespanProvider, WeightedBooleanProvider, ) -from datafaker.utils import get_vocabulary_table_names, import_file +from datafaker.utils import constraint_name, import_file def make_generic() -> Generic: + """Make the generic provider instance.""" g = Generic(locale=Locale.EN_GB) g.add_providers( BytesProvider, @@ -36,19 +37,6 @@ def make_generic() -> Generic: return g -generic = make_generic() - - -def reset_generic() -> None: - """ - Reset all the generators. - - Only really useful in test code. - """ - global generic - generic = make_generic() - - def _eval_structure(config: Any, context: Mapping) -> Any: """ Turn a structure from ``config.yaml`` into a Python object. @@ -58,6 +46,7 @@ def _eval_structure(config: Any, context: Mapping) -> Any: """ if isinstance(config, str): try: + # pylint: disable=eval-used return eval(config, None, context) except SyntaxError as exc: raise exc @@ -82,19 +71,19 @@ def _get_object(class_name: str, context: Mapping) -> Any: """ parts = class_name.split(".") if parts[0] not in context: - raise ValueError('No such object "%"', parts[0]) + raise ValueError(f'No such object "{parts[0]}"') value = context[parts[0]] so_far = parts[0] for part in parts[1:]: so_far += "." 
+ part if not hasattr(value, part): - raise ValueError('No such attribute "%"', so_far) + raise ValueError(f'No such attribute "{so_far}"') value = getattr(value, part) return value def _call_from_context( - callable_name: str, args: list[Any], kwargs: dict[str, Any], context: Mapping + callable_name: str, args: Sequence[Any], kwargs: Mapping[str, Any], context: Mapping ) -> Any: """ Call a callable from the classes (or functions) in the context. @@ -104,7 +93,7 @@ def _call_from_context( :return: Constructed object, or None if this did not work. """ cls = _get_object(callable_name, context) - if not isinstance(cls, Callable): + if not callable(cls): return None arg_objs = [_eval_structure(arg, context) for arg in args] kwarg_objs = {k: _eval_structure(v, context) for k, v in kwargs.items()} @@ -112,6 +101,7 @@ def _call_from_context( def call_function(fn: FunctionCall, context: Mapping) -> Any: + """Call ``fn`` within the provided context.""" return _call_from_context( fn.function_name, fn.args, @@ -128,6 +118,7 @@ def get_symbols( metadata: sqlalchemy.MetaData, ) -> dict[str, Any]: """Get the symbols that may be referred to by various configuration settings.""" + generic = make_generic() symbols = { "metadata": metadata, "generic": generic, @@ -180,6 +171,8 @@ def _get_symbols_instantiation(symbols: dict[str, Any], objs: dict[str, Any]) -> class TableGenerator: + """Puts generated values into a destination table.""" + def __init__( self, dst_db_conn: sqlalchemy.Connection, @@ -203,27 +196,27 @@ def __init__( self.context: Mapping = {} with dst_db_conn.begin(): for constraint in table_data.unique_constraints: - expr = sqlalchemy.select(constraint.columns) + expr = sqlalchemy.select(*constraint.columns) query_result = dst_db_conn.execute(expr).fetchall() - self.existing_constraint_hashes[constraint.name] = set( - [hash(tuple(result)) for result in query_result] - ) + self.existing_constraint_hashes[constraint_name(constraint)] = { + hash(tuple(result)) for result 
in query_result + } @property - def num_rows_per_pass(self): + def num_rows_per_pass(self) -> int: """Get the number of rows this generator should produce relative to all the rest.""" return self.table_data.rows_per_pass @property - def name(self): + def name(self) -> str: """Get the name of the table whose rows we are generating.""" return self.table_data.table_name def set_context(self, context: Mapping) -> None: - """Sets all the Python symbols that must be known to the configuration.""" + """Set all the Python symbols that must be known to the configuration.""" self.context = context - def __call__(self, db_conn: sqlalchemy.Connection): + def __call__(self, db_conn: sqlalchemy.Connection) -> dict[str, Any]: """Generate some rows of the relevant table in the database.""" result: dict[str, Any] = {} columns_to_generate = set(self.table_data.nonnull_columns) @@ -238,7 +231,9 @@ def __call__(self, db_conn: sqlalchemy.Connection): while columns_to_generate: if max_tries == 0: raise RuntimeError( - f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts." + "Failed to satisfy unique constraints for table" + f" {self.table_data.table_name} after" + f" {self.max_unique_constraint_tries} attempts." 
) if max_tries is not None: max_tries -= 1 @@ -256,11 +251,14 @@ def __call__(self, db_conn: sqlalchemy.Connection): columns_to_generate = set() for constraint in self.table_data.unique_constraints: cf_hash = hash(tuple(result[col.name] for col in constraint.columns)) - if cf_hash in self.existing_constraint_hashes[constraint.name]: + if ( + cf_hash + in self.existing_constraint_hashes[constraint_name(constraint)] + ): columns_to_generate.update(c.name for c in constraint.columns) for constraint in self.table_data.unique_constraints: cf_hash = hash(tuple(result[col.name] for col in constraint.columns)) - self.existing_constraint_hashes[constraint.name].add(cf_hash) + self.existing_constraint_hashes[constraint_name(constraint)].add(cf_hash) return result @@ -269,7 +267,7 @@ def _make_table_generator( table_data: TableGeneratorInfo, max_unique_constraint_tries: int | None, context: Mapping, -): +) -> TableGenerator: """Make a ``TableGenerator`` with context attached.""" gen = TableGenerator(dst_db_conn, table_data, max_unique_constraint_tries) gen.set_context(context) @@ -281,7 +279,7 @@ def get_table_generator_dict( tables_data: Iterable[TableGeneratorInfo], max_unique_constraint_tries: int | None, context: Mapping, -): +) -> dict[str, TableGenerator]: """Get a dict of table names to row generators that generate rows for that table.""" return { table_data.table_name: _make_table_generator( diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 62bc01c7..69516acb 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -8,7 +8,12 @@ from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.sql import schema, sqltypes -from datafaker.utils import get_property, make_foreign_key_name, split_column_full_name +from datafaker.utils import ( + constraint_name, + get_property, + make_foreign_key_name, + split_column_full_name, +) TableT = dict[str, typing.Any] @@ -244,7 +249,7 @@ def dict_to_unique(rep: 
dict) -> schema.UniqueConstraint: def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: """Render a dict representation of a uniqueness constraint.""" return { - "name": constraint.name, + "name": constraint_name(constraint), "columns": [str(col.name) for col in constraint.columns], } @@ -315,8 +320,8 @@ def should_ignore_fk(tables_dict: dict[str, TableT], fk: str) -> bool: :param fk: The name of the foreign key. """ (table, _column) = split_column_full_name(fk) - td = get_property(tables_dict, table, dict, {}) - return get_property(td, "ignore", bool, False) + td: dict[str, TableT] = get_property(tables_dict, table, {}) + return get_property(td, "ignore", False) def _always_false(_: str) -> bool: diff --git a/datafaker/utils.py b/datafaker/utils.py index d33cf751..2147aae5 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -21,7 +21,6 @@ Generic, Iterable, Optional, - Type, TypeVar, Union, ) @@ -38,6 +37,7 @@ from sqlalchemy.orm import Session from sqlalchemy.schema import ( AddConstraint, + ColumnCollectionConstraint, DropConstraint, ForeignKeyConstraint, MetaData, @@ -404,21 +404,36 @@ def get_flag(maybe_dict: Any, key: Any) -> bool: return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False) -def get_property(maybe_dict: Any, key: Any, required_type: type[T], default: T) -> T: +def get_property(maybe_dict: Any, key: Any, default: T) -> T: """ Get a specific property from a dict or a default if that does not exist. :param maybe_dict: A mapping, or possibly not. :param key: A key in ``maybe_dict``, or possibly not. - :param required_type: The type ``maybe_dict[key]`` needs to be an instance of. :param default: The return value if ``maybe_dict`` is not a mapping, - or if ``key`` is not a key of ``maybe_dict``. + or if ``key`` is not a key of ``maybe_dict``. Do not pass ``None``! + if you want None as the default, please use get_property_or_none :return: ``maybe_dict[key]`` if this makes sense, or ``default`` if not. 
""" if not isinstance(maybe_dict, Mapping): return default v = maybe_dict.get(key, default) - return v if isinstance(v, required_type) else default + return v if isinstance(v, type(default)) else default + + +def get_property_or_none(maybe_dict: Any, key: Any, type_: type[T]) -> T | None: + """ + Get a specific property from a dict or None if that does not exist. + + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :param type_: The type that the value retrieved should have. + :return: ``maybe_dict[key]`` if this makes sense, or ``default`` if not. + """ + if not isinstance(maybe_dict, Mapping) or key not in maybe_dict: + return None + v = maybe_dict[key] + return v if isinstance(v, type_) else None def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: @@ -435,6 +450,16 @@ def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: return False +def constraint_name(constraint: ColumnCollectionConstraint) -> str: + """Get the constraint name, synthesising it if it does not exist explicitly.""" + name = constraint.name + if isinstance(name, str): + return name + joined = "_".join(constraint.columns.keys()) + kind = constraint.__visit_name__.split("_", 1)[0] + return f"{joined}_{kind}" + + def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ Test if the constraint refers to a table marked as ignored in ``config.yaml``. @@ -546,9 +571,9 @@ def get_row_generators( :param table_config: The element from the ``tables:`` stanza of ``config.xml``. :return: Pair of (name, row generator config). 
""" - rgs = get_property(table_config, "row_generators", list, []) + rgs: list[Any] = get_property(table_config, "row_generators", []) for rg in rgs: - name = rg.get("name", None) + name = get_property_or_none(rg, "name", str) if name: yield (name, rg) diff --git a/tests/test_create.py b/tests/test_create.py index 48b0e695..dc148d47 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -352,7 +352,7 @@ class CreateDataTestCase(RequiresDBTestCase): def test_create_data_minimal(self) -> None: """Test creating one table with one PK column.""" - config = {} + config: dict[str, Any] = {} orm = { "tables": { "one": { @@ -385,6 +385,7 @@ def test_create_data_minimal(self) -> None: self.assertEqual(row_counts["one"], generate_count) def test_unique_constraint_minimal(self) -> None: + """Test that unique constraints cause a failure with a constant provider.""" config = { "tables": { "one": { diff --git a/tests/test_functional.py b/tests/test_functional.py index a77f9c96..bfc6e8cc 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -277,12 +277,10 @@ def test_workflow_maximal_args(self) -> None: "Creating data.", "Generating data for story 'story_generators.short_story'", "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.full_row_story'", "Generating data for story 'story_generators.long_story'", "Generating data for table 'data_type_test'", "Generating data for table 'no_pk_test'", "Generating data for table 'person'", - "Generating data for table 'person'", "Generating data for table 'strange_type_table'", "Generating data for table 'unique_constraint_test'", "Generating data for table 'unique_constraint_test2'", diff --git a/tests/test_rst.py b/tests/test_rst.py index 5baf6062..ee89b5b5 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -10,20 +10,6 @@ from sphinxcontrib.mermaid import Mermaid -def _level_to_string(level: int) -> str: - """Get a string description of an error 
level.""" - return ["Severe", "Error", "Warning"][level] - - -def _error_message(lint_error: Any) -> str: - """Turn a linting error into an error message.""" - source = getattr(lint_error, "source") - line = getattr(lint_error, "line") - level = _level_to_string(getattr(lint_error, "level")) - message = getattr(lint_error, "full_message") - return f"{source}({line}): {level}: {message}" - - def _level_to_string(level: int) -> str: """Get a string description of an error level.""" return ["Severe", "Error", "Warning"][level] diff --git a/tests/utils.py b/tests/utils.py index b07aadae..e364bc5d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,14 +27,12 @@ from datafaker.create import create_db_data_into, create_db_tables_into from datafaker.interactive.base import DbCmd from datafaker.make import make_src_stats, make_tables_file -from datafaker.populate import reset_generic from datafaker.utils import ( MaybeAsyncEngine, T, create_db_engine, create_db_engine_dst, get_sync_engine, - import_file, sorted_non_vocabulary_tables, ) @@ -245,7 +243,6 @@ def setUp(self) -> None: """Set up the test case with an actual orm.yaml file.""" super().setUp() settings.get_settings.cache_clear() - reset_generic() if self.use_temporary_cwd: self.start_dir = os.getcwd() self.working_dir = mkdtemp("test") From 675082f9f21a34050f2907a4409353b54308c04e Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 24 Mar 2026 18:06:36 +0000 Subject: [PATCH 39/44] Version bump to 0.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 912fc6ed..7983ab2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "datafaker" -version = "0.2.3" +version = "0.3.0" description = "Generates fake SQL data" authors = ["Tim Band <3266052+tim-band@users.noreply.github.com>"] license = "MIT" From 48e9ec861c9ecb658ebb335feb914322f4043ddd Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Apr 2026 11:37:11 
+0100 Subject: [PATCH 40/44] Removed some more df.py remenants --- datafaker/generators/base.py | 2 +- datafaker/main.py | 3 +-- datafaker/make.py | 25 ++++--------------------- docs/source/duckdb.rst | 3 +-- docs/source/introduction.rst | 6 ------ docs/source/loan_data.rst | 13 ++----------- docs/source/overview.rst | 1 - docs/source/quickstart.rst | 7 ++----- tests/test_functional.py | 1 - 9 files changed, 11 insertions(+), 50 deletions(-) diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py index 0ded64c1..613a191d 100644 --- a/datafaker/generators/base.py +++ b/datafaker/generators/base.py @@ -42,7 +42,7 @@ class Generator(ABC): @abstractmethod def function_name(self) -> str: - """Get the name of the generator function to put into df.py.""" + """Get the name of the generator function to call to generate the data.""" def name(self) -> str: """ diff --git a/datafaker/main.py b/datafaker/main.py index c80af073..1a709bd0 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -55,7 +55,6 @@ ORM_FILENAME: Final[str] = "orm.yaml" CONFIG_FILENAME: Final[str] = "config.yaml" -DF_FILENAME: Final[str] = "df.py" STATS_FILENAME: Final[str] = "src-stats.yaml" app = Typer(no_args_is_help=True) @@ -270,7 +269,7 @@ def create_generators( dir_okay=False, ), _df_file: Path = Option( - DF_FILENAME, + None, help="Path to write Python generators to.", dir_okay=False, ), diff --git a/datafaker/make.py b/datafaker/make.py index 997322de..c9bb1667 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -13,8 +13,6 @@ import snsql import typer import yaml -from black import FileMode, format_str -from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider from sqlalchemy import CursorResult, Engine, MetaData, text from sqlalchemy.dialects import postgresql @@ -58,13 +56,10 @@ if issubclass(entry, BaseProvider) and entry.__module__ == "datafaker.providers": PROVIDER_IMPORTS.append(entry_name) -TEMPLATE_DIRECTORY: 
Final[Path] = Path(__file__).parent / "templates/" -DF_TEMPLATE_FILENAME: Final[str] = "df.py.j2" - @dataclass class VocabularyTableGeneratorInfo: - """Contains the df.py content related to vocabulary tables.""" + """Contains the vocabulary tables to be generated.""" variable_name: str table_name: str @@ -82,7 +77,7 @@ class FunctionCall: @dataclass class RowGeneratorInfo: - """Contains the df.py content related to row generators of a table.""" + """Contains the row generators of a table.""" variable_names: list[str] function_call: FunctionCall @@ -122,7 +117,7 @@ def make_column_choices( @dataclass class TableGeneratorInfo: - """Contains the df.py content related to regular tables.""" + """Contains the tables that need data generation.""" class_name: str table_name: str @@ -137,7 +132,7 @@ class TableGeneratorInfo: @dataclass class StoryGeneratorInfo: - """Contains the df.py content related to story generators.""" + """Contains the story generators.""" wrapper_name: str function_call: FunctionCall @@ -669,18 +664,6 @@ def get_generation_info( ) -def generate_df_content(template_context: Mapping[str, Any]) -> str: - """Generate the content of the df.py file as a string.""" - environment: Environment = Environment( - loader=FileSystemLoader(TEMPLATE_DIRECTORY), - trim_blocks=True, - lstrip_blocks=True, - ) - df_template: Template = environment.get_template(DF_TEMPLATE_FILENAME) - template_output: str = df_template.render(template_context) - return format_str(template_output, mode=FileMode()) - - def _get_generator_for_existing_vocabulary_table( table: Table, ) -> VocabularyTableGeneratorInfo: diff --git a/docs/source/duckdb.rst b/docs/source/duckdb.rst index 69ddc724..47004fd4 100644 --- a/docs/source/duckdb.rst +++ b/docs/source/duckdb.rst @@ -96,7 +96,7 @@ Using DuckDB to write fake Parquet or CSV files You cannot use an in-memory DuckDB for the destination database because it needs to survive multiple calls to ``datafaker``, but Datafaker will create the DuckDB 
file for you if you set the `DST_DSN` environment variable appropriately. -After using ``datafaker create-tables``, ``datafaker create-generators``, and ``datafaker create-data``, +After using ``datafaker create-tables`` and ``datafaker create-data``, you now have a database file containing the fake data. If you want CSV or parquet files you can use the following commands: .. code-block:: shell @@ -166,7 +166,6 @@ and can create the fake data parquet files in a new directory called ``fake``: .. code-block:: shell datafaker create-tables - datafaker create-generators datafaker create-data --num-passes 100 mkdir fake datafaker dump-data --parquet --output fake diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst index 3e3523c1..60c02f91 100644 --- a/docs/source/introduction.rst +++ b/docs/source/introduction.rst @@ -70,13 +70,9 @@ And let's populate it with the fake data: export DST_DSN='postgresql://tim:password@localhost/fake_pagila' export DST_SCHEMA='public' - datafaker create-generators datafaker create-tables datafaker create-data -``create-generators`` creates a Python file called ``df.py``. -You can edit this file if you want, but it is much easier to edit ``config.yaml`` and call ``datafaker create-generators --force`` to regenerate this file. - You will notice that ``create-tables`` produces a couple of warnings, and PostgreSQL complains when ``datafaker`` tries to create the data. The warnings are that ``datafaker`` doesn't understand the special PostgresSQL types ``TSVECTOR`` and ``ARRAY``, so it doesn't know how to generate data for those columns. Because it doesn't know how to generate data for those columns it will just use NULLs, and the ``film.fulltext`` column cannot be NULL, so creating the data fails. @@ -313,7 +309,6 @@ option to create multiple rows of output. .. 
code-block:: shell - datafaker create-generators --force datafaker create-data --num-passes 3 Now let's have a look at what data we have in the destination database: @@ -589,7 +584,6 @@ followed by re-generating the data: Do you want to save this configuration? (yes/no/cancel) yes $ datafaker remove-data --yes - $ datafaker create-generators --force $ datafaker create-data --num-passes 3 $ datafaker dump-data --output - --table film description,film_id,fulltext,language_id,last_update,length,original_language_id,rating,release_year,rental_duration,rental_rate,replacement_cost,special_features,title diff --git a/docs/source/loan_data.rst b/docs/source/loan_data.rst index 400a3ec2..85abb3af 100644 --- a/docs/source/loan_data.rst +++ b/docs/source/loan_data.rst @@ -71,12 +71,6 @@ we see that they are always 0 or 1 so we will pick randomly from 0 and 1 for our .. literalinclude:: ../../examples/loans/config1.yaml :language: yaml -We run SqlSynthGen's ``create-generators`` command to create ``df.py``, which contains a generator class for each table in the source database: - -.. code-block:: console - - $ sqlsynthgen create-generators --config config.yaml - We then run SqlSynthGen's ``create-tables`` command to create the tables in the destination database: .. code-block:: console @@ -108,7 +102,6 @@ We can export the vocabularies to ``.yaml`` files, delete the old synthetic data .. code-block:: console - $ sqlsynthgen create-generators $ sqlsynthgen remove-data $ sqlsynthgen create-vocab $ sqlsynthgen create-data --num-passes 100 @@ -160,11 +153,10 @@ We add it manually to the orm.py file: ) ... -We'll need to recreate the ``df.py`` file, the destination database and the data: +We'll need to recreate the destination database and the data: .. 
code-block:: console - $ sqlsynthgen create-generators --config-file config.yaml --force $ sqlsynthgen remove-tables --yes $ sqlsynthgen create-tables $ sqlsynthgen create-vocab @@ -212,11 +204,10 @@ We define a custom row-generator to use the source statistics and Python's ``ran .. literalinclude:: ../../examples/loans/my_row_generators.py :language: python -As before, we will need to re-create ``df.py`` and the data. +We recreate the data: .. code-block:: console - $ sqlsynthgen create-generators --config-file config.yaml --force $ sqlsynthgen make-stats --config-file config.yaml --force $ sqlsynthgen remove-data --yes $ sqlsynthgen create-vocab diff --git a/docs/source/overview.rst b/docs/source/overview.rst index 6f727023..f13c1a58 100644 --- a/docs/source/overview.rst +++ b/docs/source/overview.rst @@ -319,7 +319,6 @@ as the sensitive data is no longer accessed by Datafaker. The remaining commands are: - ``datafaker create-tables`` creates the structure of the destination database to match (as much as is requested) the structure of the source database -- ``datafaker create-generators`` creates Python code files that will actually generate the data (this phase might be removed in a future version of Datafaker) - ``datafaker create-data`` writes fake data into the destination database. As these operations require no access to the sensitive data, this phase can be diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 43722aa6..7aa906c8 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -21,7 +21,6 @@ After :ref:`Installation `, we can run ``datafaker`` to see t configure-missing Interactively set the missingness of the... configure-tables Interactively set tables to ignored, vocabulary... create-data Populate the schema in the target directory with... - create-generators Make a datafaker file of generator classes. create-tables Create schema from the ORM YAML file. 
create-vocab Import vocabulary data into the target database. list-tables List the names of tables @@ -561,13 +560,11 @@ Whichever we chose, now we can create the generators Python file and generate th $ datafaker create-tables $ datafaker create-vocab - $ datafaker create-generators $ datafaker create-data --num-passes 10 The first of these uses ``orm.yaml`` to create the destination database. The second uses all the ``.yaml.gz`` (or ``.yaml``) files representing the vocabulary tables (this can take hours, too). -The third uses ``config.yaml`` to create a file ``df.py`` file containing code to call the generators as configured. -The last one actually generates the data. ``--num-passes`` controls how many rows are generated. +The third actually generates the data. ``--num-passes`` controls how many rows are generated. At present the only ways to generate different numbers of rows for different tables is to configure ``num_rows_per_pass`` in ``config.yaml``: .. code-block:: yaml @@ -575,7 +572,7 @@ At present the only ways to generate different numbers of rows for different tab observation: num_rows_per_pass: 50 -This makes every call to ``create-data`` produce 50 rows in the ``observation`` table (each time you change ``config.yaml` you need to re-run ``create-generators``). +This makes every call to ``create-data`` produce 50 rows in the ``observation`` table. If you call ``create-data`` multiple times you get more data added to whatever already exists. Call ``remove-data`` to remove all rows from all non-vocabulary tables. You can call ``remove-vocab`` to remove all rows from all vocabulary tables, and you can call ``remove-tables`` to empty the database completely. 
diff --git a/tests/test_functional.py b/tests/test_functional.py index 91f741ca..bf385aff 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -21,7 +21,6 @@ class DBFunctionalTestCaseBase(RequiresDBTestCase): examples_dir = Path("tests/examples") orm_file_path = Path("orm.yaml") - datafaker_file_path = Path("df.py") generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), From fba596ad957573056ef994ff16bb79d3ce419fcb Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Apr 2026 16:10:25 +0100 Subject: [PATCH 41/44] Rename interactive Generators to Proposers --- datafaker/generators/__init__.py | 53 ------------ datafaker/interactive/generators.py | 36 ++++---- datafaker/main.py | 2 +- datafaker/proposers/__init__.py | 52 ++++++++++++ datafaker/{generators => proposers}/base.py | 83 ++++++++++--------- datafaker/{generators => proposers}/choice.py | 14 ++-- .../{generators => proposers}/continuous.py | 34 ++++---- .../{generators => proposers}/mimesis.py | 56 ++++++------- .../{generators => proposers}/partitioned.py | 18 ++-- docs/source/duckdb.rst | 7 ++ tests/test_interactive_generators.py | 6 +- ...test_interactive_generators_partitioned.py | 8 +- 12 files changed, 189 insertions(+), 180 deletions(-) delete mode 100644 datafaker/generators/__init__.py create mode 100644 datafaker/proposers/__init__.py rename datafaker/{generators => proposers}/base.py (87%) rename datafaker/{generators => proposers}/choice.py (98%) rename datafaker/{generators => proposers}/continuous.py (96%) rename datafaker/{generators => proposers}/mimesis.py (93%) rename datafaker/{generators => proposers}/partitioned.py (98%) diff --git a/datafaker/generators/__init__.py b/datafaker/generators/__init__.py deleted file mode 100644 index 42c96040..00000000 --- a/datafaker/generators/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Generators write generator function definitions and queries into config.yaml.""" - -from collections.abc import 
Mapping -from functools import lru_cache - -from datafaker.generators.base import ( - ConstantGeneratorFactory, - GeneratorFactory, - MultiGeneratorFactory, -) -from datafaker.generators.choice import ChoiceGeneratorFactory -from datafaker.generators.continuous import ( - ContinuousDistributionGeneratorFactory, - ContinuousLogDistributionGeneratorFactory, - MultivariateLogNormalGeneratorFactory, - MultivariateNormalGeneratorFactory, -) -from datafaker.generators.mimesis import ( - MimesisDateGeneratorFactory, - MimesisDateTimeGeneratorFactory, - MimesisFloatGeneratorFactory, - MimesisIntegerGeneratorFactory, - MimesisStringGeneratorFactory, - MimesisTimeGeneratorFactory, -) -from datafaker.generators.partitioned import ( - NullPartitionedLogNormalGeneratorFactory, - NullPartitionedNormalGeneratorFactory, -) - - -def everything_factory(config: Mapping) -> GeneratorFactory: - """ - Get a factory that encapsulates all the other factories. - - :param config: The ``config.yaml`` configuration. 
- """ - return MultiGeneratorFactory( - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(config), - NullPartitionedLogNormalGeneratorFactory(config), - ) diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py index ecc8403f..75251ea0 100644 --- a/datafaker/interactive/generators.py +++ b/datafaker/interactive/generators.py @@ -8,9 +8,9 @@ import sqlalchemy from sqlalchemy import Column -from datafaker.generators import everything_factory -from datafaker.generators.base import Generator, PredefinedGenerator from datafaker.interactive.base import DbCmd, TableEntry, fk_column_name, or_default +from datafaker.proposers import everything_factory +from datafaker.proposers.base import PredefinedProposer, Proposer from datafaker.utils import ( get_columns_assigned, get_row_generators, @@ -26,7 +26,7 @@ class GeneratorInfo: """A generator and the columns it assigns to.""" columns: list[str] - gen: Generator | None + gen: Proposer | None @dataclass @@ -120,7 +120,7 @@ def make_table_entry( new_generator_infos.append( GeneratorInfo( columns=actual_collist.copy(), - gen=PredefinedGenerator(table_name, rg, self.config), + gen=PredefinedProposer(table_name, rg, self.config), ) ) columns_assigned_so_far |= colset @@ -158,7 +158,7 @@ def __init__( :param config: Configuration loaded from ``config.yaml`` """ super().__init__(settings) - self.generators: list[Generator] | None = None + self.generators: list[Proposer] | None = None self.generator_index = 0 self.generators_valid_columns: Optional[tuple[int, list[str]]] = None 
self.set_prompt() @@ -271,7 +271,7 @@ def _copy_entries(self) -> None: src_stats = self._remove_auto_src_stats() for entry in self.table_entries: rgs = [] - new_gens: list[Generator] = [] + new_gens: list[Proposer] = [] for generator in entry.new_generators: if generator.gen is not None: new_gens.append(generator.gen) @@ -316,7 +316,7 @@ def _copy_entries(self) -> None: def _find_old_generator( self, entry: GeneratorCmdTableEntry, columns: Iterable[str] - ) -> Generator | None: + ) -> Proposer | None: """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) for gen in entry.old_generators: @@ -566,13 +566,13 @@ def _generators_valid(self) -> bool: self._get_column_names(), ) - def _get_generator_proposals(self) -> list[Generator]: + def _get_generator_proposals(self) -> list[Proposer]: """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" if not self._generators_valid(): self.generators = None if self.generators is None: columns = self._column_metadata() - gens = everything_factory(self.config).get_generators( + gens = everything_factory(self.config).get_proposers( columns, self.sync_engine ) sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) @@ -615,7 +615,7 @@ def do_compare(self, arg: str) -> None: for x in self._get_column_data(limit, to_str=str) ] } - gens: list[Generator] = self._get_generator_proposals() + gens: list[Proposer] = self._get_generator_proposals() table_name = self.table_name() for argument in args: if argument.isdigit(): @@ -630,7 +630,7 @@ def do_c(self, arg: str) -> None: """Synonym for compare.""" self.do_compare(arg) - def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: + def _print_values_queried(self, table_name: str, n: int, gen: Proposer) -> None: """ Print the values queried from the database for this generator. 
@@ -653,7 +653,7 @@ def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None self._print_select_aggregate_query(table_name, gen) self._print_custom_queries(gen) - def _print_custom_queries(self, gen: Generator) -> None: + def _print_custom_queries(self, gen: Proposer) -> None: """ Print all the custom queries and all the values they get in this case. @@ -700,9 +700,7 @@ def _get_custom_queries_from( if k in actual: self._get_custom_queries_from(out, v, actual[k]) - def _get_aggregate_query( - self, gens: list[Generator], table_name: str - ) -> str | None: + def _get_aggregate_query(self, gens: list[Proposer], table_name: str) -> str | None: clauses = [ f'{q["clause"]} AS {n}' for gen in gens @@ -712,7 +710,7 @@ def _get_aggregate_query( return None return f"SELECT {', '.join(clauses)} FROM {table_name}" - def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: + def _print_select_aggregate_query(self, table_name: str, gen: Proposer) -> None: """ Print the select aggregate query and all the values it gets in this case. 
@@ -803,7 +801,7 @@ def do_p(self, arg: str) -> None: """Synonym for propose.""" self.do_propose(arg) - def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: + def get_proposed_generator_by_name(self, gen_name: str) -> Proposer | None: """Find a generator by name from the list of proposals.""" for gen in self._get_generator_proposals(): if gen.name() == gen_name: @@ -816,7 +814,7 @@ def do_set(self, arg: str) -> None: self.print("Please run 'propose' before 'set '") return gens = self._get_generator_proposals() - new_gen: Generator | None + new_gen: Proposer | None if arg.isdigit(): index = int(arg) if index < 1: @@ -837,7 +835,7 @@ def do_set(self, arg: str) -> None: self.set_generator(new_gen) self._go_next() - def set_generator(self, gen: Generator | None) -> None: + def set_generator(self, gen: Proposer | None) -> None: """Set the current column's generator.""" (table, gen_info) = self._get_table_and_generator() if table is None: diff --git a/datafaker/main.py b/datafaker/main.py index 1a709bd0..c0edfd85 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -483,7 +483,7 @@ def configure_missing( return content = yaml.dump(config_updated) config_file.write_text(content, encoding="utf-8") - logger.debug("Generators missingness in %s.", config_file) + logger.debug("Missingness generators in %s.", config_file) @app.command() diff --git a/datafaker/proposers/__init__.py b/datafaker/proposers/__init__.py new file mode 100644 index 00000000..69993942 --- /dev/null +++ b/datafaker/proposers/__init__.py @@ -0,0 +1,52 @@ +"""Generators write generator function definitions and queries into config.yaml.""" + +from collections.abc import Mapping + +from datafaker.proposers.base import ( + ConstantProposerFactory, + MultiProposerFactory, + ProposerFactory, +) +from datafaker.proposers.choice import ChoiceProposerFactory +from datafaker.proposers.continuous import ( + ContinuousDistributionProposerFactory, + 
ContinuousLogDistributionProposerFactory, + MultivariateLogNormalProposerFactory, + MultivariateNormalProposerFactory, +) +from datafaker.proposers.mimesis import ( + MimesisDateProposerFactory, + MimesisDateTimeProposerFactory, + MimesisFloatProposerFactory, + MimesisIntegerProposerFactory, + MimesisStringProposerFactory, + MimesisTimeProposerFactory, +) +from datafaker.proposers.partitioned import ( + NullPartitionedLogNormalProposerFactory, + NullPartitionedNormalProposerFactory, +) + + +def everything_factory(config: Mapping) -> ProposerFactory: + """ + Get a factory that encapsulates all the other factories. + + :param config: The ``config.yaml`` configuration. + """ + return MultiProposerFactory( + MimesisStringProposerFactory(), + MimesisIntegerProposerFactory(), + MimesisFloatProposerFactory(), + MimesisDateProposerFactory(), + MimesisDateTimeProposerFactory(), + MimesisTimeProposerFactory(), + ContinuousDistributionProposerFactory(), + ContinuousLogDistributionProposerFactory(), + ChoiceProposerFactory(), + ConstantProposerFactory(), + MultivariateNormalProposerFactory(), + MultivariateLogNormalProposerFactory(), + NullPartitionedNormalProposerFactory(config), + NullPartitionedLogNormalProposerFactory(config), + ) diff --git a/datafaker/generators/base.py b/datafaker/proposers/base.py similarity index 87% rename from datafaker/generators/base.py rename to datafaker/proposers/base.py index 613a191d..1692e3a5 100644 --- a/datafaker/generators/base.py +++ b/datafaker/proposers/base.py @@ -1,4 +1,4 @@ -"""Basic Generators and factories.""" +"""Basic Proposers and their factories.""" import re from abc import ABC, abstractmethod @@ -22,18 +22,18 @@ generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) -class GeneratorError(Exception): - """Error thrown from Datafaker Generators.""" +class ProposerError(Exception): + """Error thrown from Datafaker Proposers.""" -class Generator(ABC): +class Proposer(ABC): """ - Random data generator. 
+ Random data generator proposer. - A generator is specific to a particular column in a particular table in - a particluar database. + A proposer is specific to a particular column (set) in a particular table + in a particular database. - A generator knows how to fetch its summary data from the database, how to calculate + A proposer knows how to fetch its summary data from the database, how to calculate its fit (if apropriate) and which function actually does the generation. It also knows these summary statistics for the column it was instantiated on, @@ -46,10 +46,10 @@ def function_name(self) -> str: def name(self) -> str: """ - Get the name of the generator. + Get the name of the proposer. Usually the same as the function name, but can be different to distinguish - between generators that have the same function but different queries. + between proposers that have the same function but different queries. """ return self.function_name() @@ -128,8 +128,13 @@ def fit(self, default: float = -1) -> float: return default -class PredefinedGenerator(Generator): - """Generator built from an existing config.yaml.""" +class PredefinedProposer(Proposer): + """ + Proposer built from an existing config.yaml. + + Does not actually propose, it just represents generators + that have been defined previously. + """ SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") @@ -159,13 +164,13 @@ def __init__( config: Mapping[str, Any], ): """ - Initialise a generator from a config.yaml. + Initialise a proposer from a config.yaml. :param config: The entire configuration. 
:param generator_object: The part of the configuration at tables.*.row_generators """ logger.debug( - "Creating a PredefinedGenerator %s from table %s", + "Creating a PredefinedProposer %s from table %s", generator_object["name"], table_name, ) @@ -232,27 +237,27 @@ def actual_kwargs(self) -> dict[str, Any]: """Get the kwargs (summary statistics) this generator was instantiated with.""" # Run the queries from nominal_kwargs # ... - logger.error("PredefinedGenerator.actual_kwargs not implemented yet") + logger.error("PredefinedProposer.actual_kwargs not implemented yet") return {} def generate_data(self, count: int) -> list[Any]: """Generate ``count`` random data points for this column.""" # Call the function if we can. This could be tricky... # ... - logger.error("PredefinedGenerator.generate_data not implemented yet") + logger.error("PredefinedProposer.generate_data not implemented yet") return [] -class GeneratorFactory(ABC): - """A factory for making generators appropriate for a database column.""" +class ProposerFactory(ABC): + """A factory for making proposers appropriate for a database column.""" @abstractmethod - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine, - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" + ) -> Sequence[Proposer]: + """Get the proposers appropriate to these columns.""" def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: @@ -358,22 +363,22 @@ def fit_from_values(self, values: list[float]) -> float: return self.fit_from_counts(buckets) -class MultiGeneratorFactory(GeneratorFactory): +class MultiProposerFactory(ProposerFactory): """A composite factory.""" - def __init__(self, *factories: GeneratorFactory): - """Initialise a MultiGeneratorFactory.""" + def __init__(self, *factories: ProposerFactory): + """Initialise a MultiProposerFactory.""" super().__init__() self.factories = factories - def get_generators( + def get_proposers( self, 
columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" + ) -> Sequence[Proposer]: + """Get the proposers appropriate to these columns.""" return [ generator for factory in self.factories - for generator in factory.get_generators(columns, engine) + for generator in factory.get_proposers(columns, engine) ] @@ -385,11 +390,11 @@ def get_column_type(column: Column) -> TypeEngine: return column.type -class ConstantGenerator(Generator): - """Generator that always produces the same value.""" +class ConstantProposer(Proposer): + """Proposer for a generator that always produces the same value.""" def __init__(self, value: Any) -> None: - """Initialise the ConstantGenerator.""" + """Initialise the ConstantProposer.""" super().__init__() self.value = value self.repr = repr(value) @@ -411,23 +416,23 @@ def generate_data(self, count: int) -> list[Any]: return [self.value for _ in range(count)] -class ConstantGeneratorFactory(GeneratorFactory): - """Just the null generator.""" +class ConstantProposerFactory(ProposerFactory): + """Propose just the null generator.""" - def get_generators( + def get_proposers( self, columns: list[Column], _engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate for these columns.""" if len(columns) != 1: return [] column = columns[0] if column.nullable: - return [ConstantGenerator(None)] + return [ConstantProposer(None)] c_type = get_column_type(column) if isinstance(c_type, String): - return [ConstantGenerator("")] + return [ConstantProposer("")] if isinstance(c_type, Numeric): - return [ConstantGenerator(0.0)] + return [ConstantProposer(0.0)] if isinstance(c_type, Integer): - return [ConstantGenerator(0)] + return [ConstantProposer(0)] return [] diff --git a/datafaker/generators/choice.py b/datafaker/proposers/choice.py similarity index 98% rename from datafaker/generators/choice.py rename to datafaker/proposers/choice.py index 
6f153037..13fbbb78 100644 --- a/datafaker/generators/choice.py +++ b/datafaker/proposers/choice.py @@ -8,9 +8,9 @@ from sqlalchemy import Column, CursorResult, Engine, text -from datafaker.generators.base import ( - Generator, - GeneratorFactory, +from datafaker.proposers.base import ( + Proposer, + ProposerFactory, dist_gen, fit_from_buckets, ) @@ -44,7 +44,7 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None yield x -class ChoiceGenerator(Generator): +class ChoiceGenerator(Proposer): """Base generator for all generators producing choices of items.""" STORE_COUNTS = False @@ -287,15 +287,15 @@ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: self.cvs_not_suppressed = cvs_not_suppressed -class ChoiceGeneratorFactory(GeneratorFactory): +class ChoiceProposerFactory(ProposerFactory): """All generators that want an average and standard deviation.""" SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 7 - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] diff --git a/datafaker/generators/continuous.py b/datafaker/proposers/continuous.py similarity index 96% rename from datafaker/generators/continuous.py rename to datafaker/proposers/continuous.py index 3d9b9c80..ce2352a7 100644 --- a/datafaker/generators/continuous.py +++ b/datafaker/proposers/continuous.py @@ -10,18 +10,18 @@ from sqlalchemy.types import Integer, Numeric from typing_extensions import Self -from datafaker.generators.base import ( +from datafaker.proposers.base import ( Buckets, - Generator, - GeneratorFactory, NumericType, + Proposer, + ProposerFactory, dist_gen, get_column_type, ) from datafaker.utils import logger -class ContinuousDistributionGenerator(Generator): +class ContinuousDistributionGenerator(Proposer): """Base class for generators producing continuous 
distributions.""" expected_buckets: Sequence[NumericType] = [] @@ -133,7 +133,7 @@ def generate_data(self, count: int) -> list[Any]: ] -class ContinuousDistributionGeneratorFactory(GeneratorFactory): +class ContinuousDistributionProposerFactory(ProposerFactory): """All generators that want an average and standard deviation.""" def _get_generators_from_buckets( @@ -142,15 +142,15 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: return [ GaussianGenerator(table_name, column_name, buckets), UniformGenerator(table_name, column_name, buckets), ] - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -168,7 +168,7 @@ def get_generators( ) -class LogNormalGenerator(Generator): +class LogNormalGenerator(Proposer): """Generator producing numbers in a log-normal distribution.""" # R: @@ -266,7 +266,7 @@ def fit(self, default: float = -1) -> float: return self.buckets.fit_from_counts(self.expected_buckets) -class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): +class ContinuousLogDistributionProposerFactory(ContinuousDistributionProposerFactory): """All generators that want an average and standard deviation of log data.""" def _get_generators_from_buckets( @@ -275,7 +275,7 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: with engine.connect() as connection: result = connection.execute( text( @@ -298,7 +298,7 @@ def _get_generators_from_buckets( ] -class MultivariateNormalGenerator(Generator): +class MultivariateNormalGenerator(Proposer): """Generator of multiple values drawn from a multivariate normal distribution.""" # pylint: disable=too-many-arguments 
too-many-positional-arguments @@ -356,7 +356,7 @@ def fit(self, default: float = -1) -> float: return default -class MultivariateNormalGeneratorFactoryBase(GeneratorFactory): +class MultivariateNormalGeneratorFactoryBase(ProposerFactory): """Generator factory that makes distributions and maybe partitions.""" @abstractmethod @@ -592,7 +592,7 @@ def _middle_query(self, inner_query: str) -> str: ) -class MultivariateNormalGeneratorFactory(MultivariateNormalGeneratorFactoryBase): +class MultivariateNormalProposerFactory(MultivariateNormalGeneratorFactoryBase): """Normal distribution generator factory.""" def function_name(self) -> str: @@ -614,9 +614,9 @@ def query_comment(self) -> str: " normal distribution over the columns {columns}." ) - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators for these columns.""" # For the case of one column we'll use GaussianGenerator if len(columns) < 2: @@ -649,7 +649,7 @@ def get_generators( ] -class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): +class MultivariateLogNormalProposerFactory(MultivariateNormalProposerFactory): """Multivariate lognormal generator factory.""" def function_name(self) -> str: diff --git a/datafaker/generators/mimesis.py b/datafaker/proposers/mimesis.py similarity index 93% rename from datafaker/generators/mimesis.py rename to datafaker/proposers/mimesis.py index 78894ea6..14c0ca4f 100644 --- a/datafaker/generators/mimesis.py +++ b/datafaker/proposers/mimesis.py @@ -8,11 +8,11 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from datafaker.generators.base import ( +from datafaker.proposers.base import ( Buckets, - Generator, - GeneratorError, - GeneratorFactory, + Proposer, + ProposerError, + ProposerFactory, get_column_type, ) from datafaker.providers import DistributionProvider @@ -27,7 +27,7 @@ 
generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) -class MimesisGeneratorBase(Generator): +class MimesisGeneratorBase(Proposer): """Base class for a generator using Mimesis.""" def __init__( @@ -43,12 +43,12 @@ def __init__( f = generic for part in function_name.split("."): if not hasattr(f, part): - raise GeneratorError( + raise ProposerError( f"Mimesis does not have a function {function_name}: {part} not found" ) f = getattr(f, part) if not callable(f): - raise GeneratorError( + raise ProposerError( f"Mimesis object {function_name} is not a callable," " so cannot be used as a generator" ) @@ -184,7 +184,7 @@ def __init__( @classmethod def make_singleton( cls, column: Column, engine: Engine, function_name: str - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Make the appropriate generation configuration for this column.""" extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" @@ -255,7 +255,7 @@ def generate_data(self, count: int) -> list[Any]: ] -class MimesisStringGeneratorFactory(GeneratorFactory): +class MimesisStringProposerFactory(ProposerFactory): """All Mimesis generators that return strings.""" GENERATOR_NAMES = [ @@ -294,8 +294,8 @@ class MimesisStringGeneratorFactory(GeneratorFactory): def _get_generators_with( self, gen_class: Callable, **kwargs: Any - ) -> list[Generator]: - gens: list[Generator] = [] + ) -> list[Proposer]: + gens: list[Proposer] = [] for name in self.GENERATOR_NAMES: try: gens.append(gen_class(name, **kwargs)) @@ -303,9 +303,9 @@ def _get_generators_with( pass return gens - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -341,12 +341,12 @@ def get_generators( ) -class MimesisFloatGeneratorFactory(GeneratorFactory): +class MimesisFloatProposerFactory(ProposerFactory): """All Mimesis 
generators that return floating point numbers.""" - def get_generators( + def get_proposers( self, columns: list[Column], _engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -363,12 +363,12 @@ def get_generators( ) -class MimesisDateGeneratorFactory(GeneratorFactory): +class MimesisDateProposerFactory(ProposerFactory): """All Mimesis generators that return dates.""" - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -379,12 +379,12 @@ def get_generators( return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") -class MimesisDateTimeGeneratorFactory(GeneratorFactory): +class MimesisDateTimeProposerFactory(ProposerFactory): """All Mimesis generators that return datetimes.""" - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -397,12 +397,12 @@ def get_generators( ) -class MimesisTimeGeneratorFactory(GeneratorFactory): +class MimesisTimeProposerFactory(ProposerFactory): """All Mimesis generators that return times.""" - def get_generators( + def get_proposers( self, columns: list[Column], _engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] @@ -413,12 +413,12 @@ def get_generators( return [MimesisGenerator("datetime.time")] -class MimesisIntegerGeneratorFactory(GeneratorFactory): +class MimesisIntegerProposerFactory(ProposerFactory): """All Mimesis generators that return integers.""" - def get_generators( + def get_proposers( self, columns: list[Column], _engine: Engine - ) -> 
Sequence[Generator]: + ) -> Sequence[Proposer]: """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] diff --git a/datafaker/generators/partitioned.py b/datafaker/proposers/partitioned.py similarity index 98% rename from datafaker/generators/partitioned.py rename to datafaker/proposers/partitioned.py index a7b9327f..60216c3b 100644 --- a/datafaker/generators/partitioned.py +++ b/datafaker/proposers/partitioned.py @@ -9,10 +9,10 @@ from sqlalchemy import Column, Connection, Engine, RowMapping, text from sqlalchemy.types import Integer, Numeric -from datafaker.generators.base import Generator, dist_gen, get_column_type -from datafaker.generators.continuous import ( +from datafaker.proposers.base import Proposer, dist_gen, get_column_type +from datafaker.proposers.continuous import ( CovariateQuery, - MultivariateNormalGeneratorFactory, + MultivariateNormalProposerFactory, ) from datafaker.utils import T, get_property, logger @@ -181,7 +181,7 @@ def __init__( ] + [f"{nc.column.name}: {nc.bitmask}" for nc in nullable_columns] -class NullPartitionedNormalGenerator(Generator): +class NullPartitionedNormalGenerator(Proposer): """ A generator of mixed numeric and non-numeric data. 
@@ -391,7 +391,7 @@ def __init__( self.nones[col_index] = None -class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): +class NullPartitionedNormalProposerFactory(MultivariateNormalProposerFactory): """Produces null partitioned generators, for complex interdependent data.""" SAMPLE_COUNT = MAXIMUM_CHOICES @@ -538,9 +538,9 @@ def _get_generator( ), ) - def get_generators( + def get_proposers( self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: + ) -> Sequence[Proposer]: """Get any appropriate generators for these columns.""" if len(columns) < 2: return [] @@ -548,7 +548,7 @@ def get_generators( if not nullable_columns: return [] table = columns[0].table.name - gens: list[Generator | None] = [] + gens: list[Proposer | None] = [] try: with engine.connect() as connection: cov_query = CovariateQuery(table, self) @@ -619,7 +619,7 @@ def _execute_partition_queries( return found_nonzero -class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): +class NullPartitionedLogNormalProposerFactory(NullPartitionedNormalProposerFactory): """ A generator for numeric and non-numeric columns. diff --git a/docs/source/duckdb.rst b/docs/source/duckdb.rst index 47004fd4..3098d19c 100644 --- a/docs/source/duckdb.rst +++ b/docs/source/duckdb.rst @@ -22,6 +22,13 @@ Or in Windows: set SRC_DSN=duckdb:///C:/path/to/file/duck.db set DST_DSN=duckdb:///C:/path/to/file/fake.db +Or in Windows PowerShell: + +.. code-block:: + + $env:SRC_DSN='duckdb:///C:/path/to/file/duck.db' + $env:DST_DSN='duckdb:///C:/path/to/file/fake.db' + This will use the DuckDB database in the file ``/path/to/file/duck.db`` and output to the file ``/path/to/file/fake.db``. Using Datafaker's ``create-tables`` command will create the new database file ``/path/to/file/fake.db``. 
diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py index 7ef2dbc3..7cd0a17d 100644 --- a/tests/test_interactive_generators.py +++ b/tests/test_interactive_generators.py @@ -6,9 +6,9 @@ from sqlalchemy import Connection, MetaData, select -from datafaker.generators.choice import ChoiceGeneratorFactory from datafaker.interactive.base import DbCmd from datafaker.interactive.generators import GeneratorCmd +from datafaker.proposers.choice import ChoiceProposerFactory from tests.utils import ( GeneratesDBTestCase, RequiresDBTestCase, @@ -593,8 +593,8 @@ class GeneratorsOutputTests(GeneratesDBTestCase): def setUp(self) -> None: super().setUp() - ChoiceGeneratorFactory.SAMPLE_COUNT = 500 - ChoiceGeneratorFactory.SUPPRESS_COUNT = 5 + ChoiceProposerFactory.SAMPLE_COUNT = 500 + ChoiceProposerFactory.SUPPRESS_COUNT = 5 def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: return TestGeneratorCmd( diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index e896b98f..06df29c8 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -6,8 +6,8 @@ from sqlalchemy import Connection, MetaData, insert, select -from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.interactive.base import DbCmd +from datafaker.proposers import NullPartitionedNormalProposerFactory from tests.test_interactive_generators import TestGeneratorCmd from tests.utils import GeneratesDBTestCase @@ -136,8 +136,8 @@ class NullPartitionedTests(GeneratesDBTestCase): def setUp(self) -> None: """Set up the test with specific sample and suppress counts.""" super().setUp() - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 - NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 + NullPartitionedNormalProposerFactory.SAMPLE_COUNT = 8 + NullPartitionedNormalProposerFactory.SUPPRESS_COUNT = 2 def 
_get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: """Get the configure-generators object as our command.""" @@ -425,7 +425,7 @@ def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None: """ # five will ensure that at least one group will have two elements in it, # but all three cannot. - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 + NullPartitionedNormalProposerFactory.SAMPLE_COUNT = 5 table_name = "observation" generate_count = 100 with self._get_cmd({}) as gc: From 09d2bec5ace50789452bebda21e0921ee74c34ee Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 10 Apr 2026 17:17:09 +0100 Subject: [PATCH 42/44] Fixes #88 --- datafaker/create.py | 18 ++++++++++--- tests/test_main.py | 62 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/datafaker/create.py b/datafaker/create.py index f91dbb0c..64782891 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,4 +1,5 @@ """Functions and classes to create and populate the target database.""" +import re from collections import Counter from pathlib import Path from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple @@ -33,6 +34,8 @@ Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None] RowCounts = Counter[str] +serial_re = re.compile(r"\bSERIAL\b") + @compiles(CreateColumn, "duckdb") def remove_serial(element: CreateColumn, compiler: Any, **kw: Any) -> str: @@ -48,23 +51,30 @@ def remove_serial(element: CreateColumn, compiler: Any, **kw: Any) -> str: :return: Corrected SQL. """ text: str = compiler.visit_create_column(element, **kw) - return text.replace(" SERIAL ", " INTEGER ") + return serial_re.sub("INTEGER", text) @compiles(CreateTable, "duckdb") def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) -> str: """ - Intercede in compilation for column creation, removing ``ON DELETE CASCADE``. + Intercede in compilation for column creation. 
DuckDB does not understand cascades, and we don't care about - that in datafaker. Ideally ``duckdb_engine`` would remove this for us. + that in datafaker so we remove ``ON DELETE CASCADE``. + + DuckDB does not understand ``SERIAL`` and we don't care + about autoincrementing, so we will replace it simply with + ``INTEGER``. + + Ideally ``duckdb_engine`` would remove these for us. :param element: The CreateTable being executed. :param compiler: Actually a DDLCompiler, but that type is not exported. :param kw: Further arguments. :return: Corrected SQL. """ text: str = compiler.visit_create_table(element, **kw) - return text.replace(" ON DELETE CASCADE", "") + t2 = serial_re.sub("INTEGER", text) + return t2.replace(" ON DELETE CASCADE", "") def create_db_tables(metadata: MetaData) -> None: diff --git a/tests/test_main.py b/tests/test_main.py index e81d2ea2..50dee7cf 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -10,7 +10,12 @@ from datafaker.main import app from datafaker.settings import Settings, SettingsError -from tests.utils import DatafakerTestCase, get_test_settings +from tests.utils import ( + DatafakerTestCase, + GeneratesDBTestCase, + TestDuckDb, + get_test_settings, +) runner = CliRunner(mix_stderr=False) @@ -499,3 +504,58 @@ def test_make_stats_with_force_enabled( self.load_yaml("stats_file.yaml"), {"some_stat": 0} ) self.assertSuccess(result) + + +class TestsCliCreate(GeneratesDBTestCase): + """Tests that use the CLI to generate output in the destination database.""" + + use_temporary_cwd = True + dst_schema_name = "fake.dstschema" + + def setUp(self) -> None: + """Set the runner with the environment variables we need.""" + super().setUp() + self.runner = CliRunner( + mix_stderr=False, + env={ + "src_dsn": self.dsn, + "dst_dsn": self.dst_dsn, + }, + ) + + def test_create_primary_key(self) -> None: + """Test the creation of a simple database with one primary key column.""" + orm = { + "tables": { + "tab": { + "columns": { + "col": { + 
"primary": True, + # For some reason, nullable primary keys triggers a + # weird corner case on DuckDB + "nullable": True, + "type": "INTEGER", + } + } + } + } + } + config: dict[str, Any] = {} + with Path("orm.yaml").open("w", encoding="utf-8") as fh: + fh.write(yaml.dump(orm)) + with Path("config.yaml").open("w", encoding="utf-8") as fh: + fh.write(yaml.dump(config)) + result = self.runner.invoke( + app, + [ + "create-tables", + ], + catch_exceptions=False, + ) + self.assertSuccess(result) + + +class TestCliCreateDuckDb(TestsCliCreate): + """Tests that use the CLI to generate output in a DuckDB database.""" + + database_type = TestDuckDb From 5a8df9d4e6b699c1ec82e0299ed4bcc8fb0ebe2c Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 13 Apr 2026 14:23:19 +0100 Subject: [PATCH 43/44] Rename "generator" to "proposer" in configure-generators --- datafaker/interactive/generators.py | 342 ++++++++++++++-------------- 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py index 75251ea0..e5d228d8 100644 --- a/datafaker/interactive/generators.py +++ b/datafaker/interactive/generators.py @@ -22,24 +22,24 @@ @dataclass -class GeneratorInfo: +class ProposerInfo: """A generator and the columns it assigns to.""" columns: list[str] - gen: Proposer | None + proposer: Proposer | None @dataclass class GeneratorCmdTableEntry(TableEntry): """ - List of generators set for a table. + List of proposers set for a table. Includes the original setting and the currently configured - generators. + proposers. 
""" - old_generators: list[GeneratorInfo] - new_generators: list[GeneratorInfo] + old_proposers: list[ProposerInfo] + new_proposers: list[ProposerInfo] # pylint: disable=too-many-public-methods @@ -99,7 +99,7 @@ def make_table_entry( column_set = frozenset(columns) columns_assigned_so_far: set[str] = set() - new_generator_infos: list[GeneratorInfo] = [] + new_proposer_infos: list[ProposerInfo] = [] for gen_name, rg in get_row_generators(table_config): colset: set[str] = set(get_columns_assigned(rg)) for unknown in colset - column_set: @@ -117,32 +117,32 @@ def make_table_entry( ) actual_collist = [c for c in columns if c in colset] if actual_collist: - new_generator_infos.append( - GeneratorInfo( + new_proposer_infos.append( + ProposerInfo( columns=actual_collist.copy(), - gen=PredefinedProposer(table_name, rg, self.config), + proposer=PredefinedProposer(table_name, rg, self.config), ) ) columns_assigned_so_far |= colset - old_generator_infos = [ - GeneratorInfo(columns=gi.columns.copy(), gen=gi.gen) - for gi in new_generator_infos + old_proposer_infos = [ + ProposerInfo(columns=gi.columns.copy(), proposer=gi.proposer) + for gi in new_proposer_infos ] for colname in columns: if colname not in columns_assigned_so_far: - new_generator_infos.append( - GeneratorInfo( + new_proposer_infos.append( + ProposerInfo( columns=[colname], - gen=None, + proposer=None, ) ) - if len(new_generator_infos) == 0: + if len(new_proposer_infos) == 0: return None return GeneratorCmdTableEntry( name=table_name, - old_generators=old_generator_infos, - new_generators=new_generator_infos, + old_proposers=old_proposer_infos, + new_proposers=new_proposer_infos, ) def __init__( @@ -158,9 +158,9 @@ def __init__( :param config: Configuration loaded from ``config.yaml`` """ super().__init__(settings) - self.generators: list[Proposer] | None = None - self.generator_index = 0 - self.generators_valid_columns: Optional[tuple[int, list[str]]] = None + self.proposers: list[Proposer] | None = None + 
self.proposer_index = 0 + self.proposers_valid_columns: Optional[tuple[int, list[str]]] = None self.set_prompt() @property @@ -190,7 +190,7 @@ def _set_table_index(self, index: int) -> bool: """ ret = super()._set_table_index(index) if ret: - self.generator_index = 0 + self.proposer_index = 0 self.set_prompt() return ret @@ -209,7 +209,7 @@ def _previous_table(self) -> bool: self.table_index, ) return False - self.generator_index = len(table.new_generators) - 1 + self.proposer_index = len(table.new_proposers) - 1 else: self.print(self.ERROR_ALREADY_AT_START) return ret @@ -220,19 +220,19 @@ def get_table(self) -> GeneratorCmdTableEntry | None: return self.table_entries[self.table_index] return None - def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: + def _get_table_and_proposer(self) -> tuple[str | None, ProposerInfo | None]: """Get a pair; the table name then the generator information.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] - if self.generator_index < len(entry.new_generators): - return (entry.name, entry.new_generators[self.generator_index]) + if self.proposer_index < len(entry.new_proposers): + return (entry.name, entry.new_proposers[self.proposer_index]) return (entry.name, None) return (None, None) def _get_column_names(self) -> list[str]: """Get the (unqualified) names for all the current columns.""" - (_, generator_info) = self._get_table_and_generator() - return generator_info.columns if generator_info else [] + (_, proposer_info) = self._get_table_and_proposer() + return proposer_info.columns if proposer_info else [] def _column_metadata(self) -> list[Column]: """Get the metadata for all the current columns.""" @@ -243,18 +243,18 @@ def _column_metadata(self) -> list[Column]: def set_prompt(self) -> None: """Set the prompt according to the current table, column and generator.""" - (table_name, gen_info) = self._get_table_and_generator() + (table_name, prop_info) = 
self._get_table_and_proposer() if table_name is None: self.prompt = "(generators) " return - if gen_info is None: + if prop_info is None: self.prompt = f"({table_name}) " return table = self.table_metadata() columns = [ - c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns + c + "[pk]" if table.columns[c].primary_key else c for c in prop_info.columns ] - gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" + gen = f" ({prop_info.proposer.name()})" if prop_info.proposer else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]: @@ -272,10 +272,10 @@ def _copy_entries(self) -> None: for entry in self.table_entries: rgs = [] new_gens: list[Proposer] = [] - for generator in entry.new_generators: - if generator.gen is not None: - new_gens.append(generator.gen) - cqs = generator.gen.custom_queries() + for proposer in entry.new_proposers: + if proposer.proposer is not None: + new_gens.append(proposer.proposer) + cqs = proposer.proposer.custom_queries() for cq_key, cq in cqs.items(): src_stats.append( { @@ -285,10 +285,10 @@ def _copy_entries(self) -> None: } ) rg: dict[str, Any] = { - "name": generator.gen.function_name(), - "columns_assigned": generator.columns, + "name": proposer.proposer.function_name(), + "columns_assigned": proposer.columns, } - kwn = generator.gen.nominal_kwargs() + kwn = proposer.proposer.nominal_kwargs() if kwn: rg["kwargs"] = kwn rgs.append(rg) @@ -314,14 +314,14 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table_config) self.config["src-stats"] = src_stats - def _find_old_generator( + def _find_old_proposer( self, entry: GeneratorCmdTableEntry, columns: Iterable[str] ) -> Proposer | None: - """Find any generator that previously assigned to these exact same columns.""" + """Find any proposer that previously assigned to these exact same columns.""" fc = frozenset(columns) - for gen in entry.old_generators: + for gen in 
entry.old_proposers: if frozenset(gen.columns) == fc: - return gen.gen + return gen.proposer return None def do_quit(self, arg: str) -> bool: @@ -330,9 +330,9 @@ def do_quit(self, arg: str) -> bool: for entry in self.table_entries: header_shown = False g_entry = cast(GeneratorCmdTableEntry, entry) - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - new_gen = None if gen is None else gen.gen + for gen in g_entry.new_proposers: + old_gen = self._find_old_proposer(g_entry, gen.columns) + new_gen = None if gen is None else gen.proposer if old_gen != new_gen: if not header_shown: header_shown = True @@ -342,7 +342,7 @@ def do_quit(self, arg: str) -> bool: "...changing {0} from {1} to {2}", ", ".join(gen.columns), old_gen.name() if old_gen else "nothing", - gen.gen.name() if gen.gen else "nothing", + gen.proposer.name() if gen.proposer else "nothing", ) if count == 0: self.print("You have made no changes.") @@ -361,7 +361,7 @@ def do_tables(self, _arg: str) -> None: """List the tables.""" for t_entry in self.table_entries: entry = cast(GeneratorCmdTableEntry, t_entry) - gen_count = len(entry.new_generators) + gen_count = len(entry.new_proposers) how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" self.print("{0} ({1})", entry.name, how_many) @@ -372,17 +372,17 @@ def do_list(self, _arg: str) -> None: return g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) table = self.table_metadata() - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) + for gen in g_entry.new_proposers: + old_gen = self._find_old_proposer(g_entry, gen.columns) old = "" if old_gen is None else old_gen.name() - if old_gen == gen.gen: + if old_gen == gen.proposer: becomes = "" if old == "": old = "(not set)" - elif gen.gen is None: + elif gen.proposer is None: becomes = "(delete)" else: - becomes = f"->{gen.gen.name()}" + becomes = f"->{gen.proposer.name()}" primary 
= "" if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: primary = "[primary-key]" @@ -425,18 +425,18 @@ def _get_table_index(self, table_name: str) -> int | None: return n return None - def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + def _get_proposer_index(self, table_index: int, column_name: str) -> int | None: """ - Get the index number of a column within the list of generators in this table. + Get the index number of a column within the list of proposers in this table. :param table_index: The index of the table in which to search. :param column_name: The name of the column to search for. - :return: The index in the ``new_generators`` attribute of the table entry + :return: The index in the ``new_proposers`` attribute of the table entry containing the specified column, or None if this does not exist. """ entry = self.table_entries[table_index] - for n, gen in enumerate(entry.new_generators): - if column_name in gen.columns: + for n, prop in enumerate(entry.new_proposers): + if column_name in prop.columns: return n return None @@ -447,31 +447,31 @@ def go_to(self, target: str) -> bool: :return: True on success. 
""" (first_part, last_part) = split_column_full_name(target) - gen_index: int | None = None + prop_index: int | None = None if first_part: # target == table.column table_index = self._get_table_index(first_part) if table_index is None: self.print(self.ERROR_NO_SUCH_TABLE, first_part) return False - gen_index = self._get_generator_index(table_index, last_part) - if gen_index is None: + prop_index = self._get_proposer_index(table_index, last_part) + if prop_index is None: self.print(self.ERROR_NO_SUCH_COLUMN, last_part) return False else: # target == table or target == column table_index = self._get_table_index(last_part) - gen_index = 0 + prop_index = 0 if table_index is None: # not table, perhaps it's column - gen_index = self._get_generator_index(self.table_index, last_part) - if gen_index is None: + prop_index = self._get_proposer_index(self.table_index, last_part) + if prop_index is None: # it's neither self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, last_part) return False if table_index is not None: self._set_table_index(table_index) - self.generator_index = gen_index + self.proposer_index = prop_index self.set_prompt() return True @@ -502,11 +502,11 @@ def _go_next(self) -> None: if table is None: self.print("No more tables") return - next_gi = self.generator_index + 1 - if next_gi == len(table.new_generators): + next_gi = self.proposer_index + 1 + if next_gi == len(table.new_proposers): self.next_table(self.INFO_NO_MORE_TABLES) return - self.generator_index = next_gi + self.proposer_index = next_gi self.set_prompt() def complete_next( @@ -522,8 +522,8 @@ def complete_next( table_entry = self.table_entries[table_index] return [ f"{first_part}.{column}" - for gen in table_entry.new_generators - for column in gen.columns + for prop in table_entry.new_proposers + for column in prop.columns if column.startswith(last_part) ] # first_part is None, last_part might be table or column. 
@@ -539,8 +539,8 @@ def complete_next( if current_table: column_names = [ col - for gen in current_table.new_generators - for col in gen.columns + for prop in current_table.new_proposers + for col in prop.columns if col.startswith(last_part) ] else: @@ -549,39 +549,39 @@ def complete_next( def do_previous(self, _arg: str) -> None: """Go to the previous generator.""" - if self.generator_index == 0: + if self.proposer_index == 0: self._previous_table() else: - self.generator_index -= 1 + self.proposer_index -= 1 self.set_prompt() def do_b(self, arg: str) -> None: """Synonym for previous.""" self.do_previous(arg) - def _generators_valid(self) -> bool: - """Test if ``self.generators`` is still correct for the current columns.""" - return self.generators_valid_columns == ( + def _proposers_valid(self) -> bool: + """Test if ``self.proposers`` is still correct for the current columns.""" + return self.proposers_valid_columns == ( self.table_index, self._get_column_names(), ) - def _get_generator_proposals(self) -> list[Proposer]: - """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" - if not self._generators_valid(): - self.generators = None - if self.generators is None: + def _get_proposer_proposals(self) -> list[Proposer]: + """Get a list of acceptable proposers, sorted by decreasing fit to the actual data.""" + if not self._proposers_valid(): + self.proposers = None + if self.proposers is None: columns = self._column_metadata() - gens = everything_factory(self.config).get_proposers( + props = everything_factory(self.config).get_proposers( columns, self.sync_engine ) - sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) - self.generators = sorted_gens - self.generators_valid_columns = ( + sorted_props = sorted(props, key=lambda g: g.fit(9999)) + self.proposers = sorted_props + self.proposers_valid_columns = ( self.table_index, self._get_column_names().copy(), ) - return self.generators + return self.proposers def 
_print_privacy(self) -> None: """Print the privacy status of the current table.""" @@ -615,56 +615,56 @@ def do_compare(self, arg: str) -> None: for x in self._get_column_data(limit, to_str=str) ] } - gens: list[Proposer] = self._get_generator_proposals() + props: list[Proposer] = self._get_proposer_proposals() table_name = self.table_name() for argument in args: if argument.isdigit(): n = int(argument) - if 0 < n <= len(gens): - gen = gens[n - 1] - comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit) - self._print_values_queried(table_name, n, gen) + if 0 < n <= len(props): + prop = props[n - 1] + comparison[f"{n}. {prop.name()}"] = prop.generate_data(limit) + self._print_values_queried(table_name, n, prop) self.print_table_by_columns(comparison) def do_c(self, arg: str) -> None: """Synonym for compare.""" self.do_compare(arg) - def _print_values_queried(self, table_name: str, n: int, gen: Proposer) -> None: + def _print_values_queried(self, table_name: str, n: int, prop: Proposer) -> None: """ - Print the values queried from the database for this generator. + Print the values queried from the database for this proposer. - :param table_name: The name of the table the generator applies to. + :param table_name: The name of the table the proposer applies to. :param n: A number to print at the start of the output. - :param gen: The generator to report. + :param gen: The proposer to report. """ - if not gen.select_aggregate_clauses() and not gen.custom_queries(): + if not prop.select_aggregate_clauses() and not prop.custom_queries(): self.print( "{0}. {1} requires no data from the source database.", n, - gen.name(), + prop.name(), ) else: self.print( "{0}. 
{1} requires the following data from the source database:", n, - gen.name(), + prop.name(), ) - self._print_select_aggregate_query(table_name, gen) - self._print_custom_queries(gen) + self._print_select_aggregate_query(table_name, prop) + self._print_custom_queries(prop) - def _print_custom_queries(self, gen: Proposer) -> None: + def _print_custom_queries(self, prop: Proposer) -> None: """ Print all the custom queries and all the values they get in this case. - :param gen: The generator to print the custom queries for. + :param prop: The proposer to print the custom queries for. """ - cqs = gen.custom_queries() + cqs = prop.custom_queries() if not cqs: return cq_key2args: dict[str, Any] = {} - nominal = gen.nominal_kwargs() - actual = gen.actual_kwargs() + nominal = prop.nominal_kwargs() + actual = prop.actual_kwargs() self._get_custom_queries_from( cq_key2args, nominal, @@ -710,21 +710,21 @@ def _get_aggregate_query(self, gens: list[Proposer], table_name: str) -> str | N return None return f"SELECT {', '.join(clauses)} FROM {table_name}" - def _print_select_aggregate_query(self, table_name: str, gen: Proposer) -> None: + def _print_select_aggregate_query(self, table_name: str, prop: Proposer) -> None: """ Print the select aggregate query and all the values it gets in this case. This is not the entire query that will be executed, but only the part of it - that is required by a certain generator. + that is required by a certain proposer. :param table_name: The table name. - :param gen: The generator to limit the aggregate query to. + :param prop: The proposer to limit the aggregate query to. 
""" - sacs = gen.select_aggregate_clauses() + sacs = prop.select_aggregate_clauses() if not sacs: return - kwa = gen.actual_kwargs() + kwa = prop.actual_kwargs() vals = [] - src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} + src_stat2kwarg = {v: k for k, v in prop.nominal_kwargs().items()} for n in sacs.keys(): src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' if src_stat in src_stat2kwarg: @@ -733,7 +733,7 @@ def _print_select_aggregate_query(self, table_name: str, gen: Proposer) -> None: vals.append(kwa[ak]) else: logger.warning( - "actual_kwargs for %s does not report %s", gen.name(), ak + "actual_kwargs for %s does not report %s", prop.name(), ak ) else: logger.warning( @@ -741,11 +741,11 @@ def _print_select_aggregate_query(self, table_name: str, gen: Proposer) -> None: "nominal_kwargs for %s does not have a value" ' SRC_STATS["auto__%s"]["results"][0]["%s"]' ), - gen.name(), + prop.name(), table_name, n, ) - select_q = self._get_aggregate_query([gen], table_name) + select_q = self._get_aggregate_query([prop], table_name) self.print("{0}; providing the following values: {1}", select_q, vals) def _get_column_data( @@ -772,17 +772,17 @@ def do_propose(self, _arg: str) -> None: the column and against each other) with the 'compare' command. 
""" limit = 5 - gens = self._get_generator_proposals() + props = self._get_proposer_proposals() sample = self._get_column_data(limit) if sample: rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) else: self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) - if not gens: + if not props: self.print(self.PROPOSE_NOTHING) - for index, gen in enumerate(gens): - fit = gen.fit(-1) + for index, prop in enumerate(props): + fit = prop.fit(-1) if fit == -1: fit_s = "(no fit)" elif fit < 100: @@ -792,28 +792,28 @@ def do_propose(self, _arg: str) -> None: self.print( self.PROPOSE_GENERATOR_SAMPLE_TEXT, index=index + 1, - name=gen.name(), + name=prop.name(), fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))), + sample="; ".join(map(repr, prop.generate_data(limit))), ) def do_p(self, arg: str) -> None: """Synonym for propose.""" self.do_propose(arg) - def get_proposed_generator_by_name(self, gen_name: str) -> Proposer | None: - """Find a generator by name from the list of proposals.""" - for gen in self._get_generator_proposals(): + def get_proposer_by_name(self, gen_name: str) -> Proposer | None: + """Find a proposer by name from the list of proposals.""" + for gen in self._get_proposer_proposals(): if gen.name() == gen_name: return gen return None def do_set(self, arg: str) -> None: """Set one of the proposals as a generator.""" - if arg.isdigit() and not self._generators_valid(): + if arg.isdigit() and not self._proposers_valid(): self.print("Please run 'propose' before 'set '") return - gens = self._get_generator_proposals() + gens = self._get_proposer_proposals() new_gen: Proposer | None if arg.isdigit(): index = int(arg) @@ -828,23 +828,23 @@ def do_set(self, arg: str) -> None: return new_gen = gens[index - 1] else: - new_gen = self.get_proposed_generator_by_name(arg) + new_gen = self.get_proposer_by_name(arg) if new_gen is None: self.print("'{0}' is not an appropriate generator for this column", 
arg) return - self.set_generator(new_gen) + self.set_proposer(new_gen) self._go_next() - def set_generator(self, gen: Proposer | None) -> None: + def set_proposer(self, prop: Proposer | None) -> None: """Set the current column's generator.""" - (table, gen_info) = self._get_table_and_generator() + (table, gen_info) = self._get_table_and_proposer() if table is None: self.print("Error: no table") return if gen_info is None: self.print("Error: no column") return - gen_info.gen = gen + gen_info.proposer = prop def do_s(self, arg: str) -> None: """Synonym for set.""" @@ -852,7 +852,7 @@ def do_s(self, arg: str) -> None: def do_unset(self, _arg: str) -> None: """Remove any generator set for this column.""" - self.set_generator(None) + self.set_proposer(None) self._go_next() def merge_columns(self, arg: str) -> bool: @@ -870,7 +870,7 @@ def merge_columns(self, arg: str) -> bool: return False cols_available = functools.reduce( lambda x, y: x | y, - [frozenset(gen.columns) for gen in table_entry.new_generators], + [frozenset(gen.columns) for gen in table_entry.new_proposers], ) cols_to_merge = frozenset(cols) unknown_cols = cols_to_merge - cols_available @@ -878,22 +878,22 @@ def merge_columns(self, arg: str) -> bool: for uc in unknown_cols: self.print(self.ERROR_NO_SUCH_COLUMN, uc) return False - gen_info = table_entry.new_generators[self.generator_index] + gen_info = table_entry.new_proposers[self.proposer_index] stated_current_columns = cols_to_merge & frozenset(gen_info.columns) if stated_current_columns: for c in stated_current_columns: self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) return False - # Remove cols_to_merge from each generator - new_new_generators: list[GeneratorInfo] = [] - for gen in table_entry.new_generators: + # Remove cols_to_merge from each proposer + new_new_proposers: list[ProposerInfo] = [] + for gen in table_entry.new_proposers: if gen is gen_info: - # Add columns to this generator - self.generator_index = len(new_new_generators) - 
new_new_generators.append( - GeneratorInfo( + # Add columns to this proposer + self.proposer_index = len(new_new_proposers) + new_new_proposers.append( + ProposerInfo( columns=gen.columns + cols, - gen=None, + proposer=None, ) ) else: @@ -901,14 +901,14 @@ def merge_columns(self, arg: str) -> bool: new_columns = [c for c in gen.columns if c not in cols_to_merge] is_changed = len(new_columns) != len(gen.columns) if new_columns: - # We have not removed this generator completely - new_new_generators.append( - GeneratorInfo( + # We have not removed this proposer completely + new_new_proposers.append( + ProposerInfo( columns=new_columns, - gen=None if is_changed else gen.gen, + proposer=None if is_changed else gen.proposer, ) ) - table_entry.new_generators = new_new_generators + table_entry.new_proposers = new_new_proposers self.set_prompt() return True @@ -926,8 +926,8 @@ def complete_merge( return [] return [ column - for i, gen in enumerate(table_entry.new_generators) - if i != self.generator_index + for i, gen in enumerate(table_entry.new_proposers) + if i != self.proposer_index for column in gen.columns if column.startswith(last_arg) ] @@ -941,8 +941,8 @@ def do_unmerge(self, arg: str) -> None: if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) + prop_info = table_entry.new_proposers[self.proposer_index] + current_columns = frozenset(prop_info.columns) cols_to_unmerge = frozenset(cols) unknown_cols = cols_to_unmerge - current_columns if unknown_cols: @@ -959,15 +959,15 @@ def do_unmerge(self, arg: str) -> None: return # Remove unmerged columns for um in cols_to_unmerge: - gen_info.columns.remove(um) - # The existing generator will not work - gen_info.gen = None - # And put them into a new (empty) generator - table_entry.new_generators.insert( - self.generator_index + 1, - GeneratorInfo( + prop_info.columns.remove(um) + # The existing 
proposer will not work + prop_info.proposer = None + # And put them into a new (empty) proposer + table_entry.new_proposers.insert( + self.proposer_index + 1, + ProposerInfo( columns=cols, - gen=None, + proposer=None, ), ) self.set_prompt() @@ -982,7 +982,7 @@ def complete_unmerge( return [] return [ column - for column in table_entry.new_generators[self.generator_index].columns + for column in table_entry.new_proposers[self.proposer_index].columns if column.startswith(last_arg) ] @@ -991,8 +991,8 @@ def get_current_columns(self) -> set[str]: table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: return set() - gen_info = table_entry.new_generators[self.generator_index] - return set(gen_info.columns) + prop_info = table_entry.new_proposers[self.proposer_index] + return set(prop_info.columns) def set_merged_columns(self, first_col: str, other_cols: str) -> bool: """ @@ -1011,17 +1011,17 @@ def set_merged_columns(self, first_col: str, other_cols: str) -> bool: return self.merge_columns(other_cols) -def try_setting_generator(gc: GeneratorCmd, gens: Iterable[str]) -> bool: +def try_setting_generator(gc: GeneratorCmd, proposers: Iterable[str]) -> bool: """ - Set the current generator by name if possible. + Set the current proposer by name if possible. :param gc: The interactive ``GeneratorCmd`` to use. - :param gens: A list of names of generators to try, in order. - :return: True if one of the generators was successfully set, False otherwise. + :param proposers: A list of names of proposers to try, in order. + :return: True if one of the proposers was successfully set, False otherwise. 
""" - for gen in gens: - new_gen = gc.get_proposed_generator_by_name(gen) - if new_gen is not None: - gc.set_generator(new_gen) + for prop in proposers: + new_prop = gc.get_proposer_by_name(prop) + if new_prop is not None: + gc.set_proposer(new_prop) return True return False From 756009182c976ecbd60654b61dea25279690149c Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 13 Apr 2026 18:47:49 +0100 Subject: [PATCH 44/44] More symbol changes generator->proposer --- datafaker/proposers/choice.py | 40 +++++++++++++++--------------- datafaker/proposers/continuous.py | 24 +++++++++--------- datafaker/proposers/mimesis.py | 30 +++++++++++----------- datafaker/proposers/partitioned.py | 6 ++--- 4 files changed, 50 insertions(+), 50 deletions(-) diff --git a/datafaker/proposers/choice.py b/datafaker/proposers/choice.py index 13fbbb78..2d9b6340 100644 --- a/datafaker/proposers/choice.py +++ b/datafaker/proposers/choice.py @@ -44,8 +44,8 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None yield x -class ChoiceGenerator(Proposer): - """Base generator for all generators producing choices of items.""" +class ChoiceProposer(Proposer): + """Base proposer for all proposers producing choices of items.""" STORE_COUNTS = False @@ -59,7 +59,7 @@ def __init__( sample_count: int | None = None, suppress_count: int = 0, ) -> None: - """Initialise a ChoiceGenerator.""" + """Initialise a ChoiceProposer.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -167,7 +167,7 @@ def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit -class ZipfChoiceGenerator(ChoiceGenerator): +class ZipfChoiceProposer(ChoiceProposer): """Generator producing items in a Zipf distribution.""" def get_estimated_counts(self, counts: list[int]) -> list[int]: @@ -201,8 +201,8 @@ def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, N yield p -class UniformChoiceGenerator(ChoiceGenerator): - 
"""A generator producing values, each roughly as frequently as each other.""" +class UniformChoiceProposer(ChoiceProposer): + """A proposer producing values, each roughly as frequently as each other.""" def get_estimated_counts(self, counts: list[int]) -> list[int]: """Get the counts that we would expect if this distribution was the correct one.""" @@ -217,8 +217,8 @@ def generate_data(self, count: int) -> list[Any]: return [dist_gen.choice_direct(self.values) for _ in range(count)] -class WeightedChoiceGenerator(ChoiceGenerator): - """Choice generator that matches the source data's frequency.""" +class WeightedChoiceProposer(ChoiceProposer): + """Choice proposer that matches the source data's frequency.""" STORE_COUNTS = True @@ -315,33 +315,33 @@ def get_proposers( vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator( + ZipfChoiceProposer( table_name, column_name, vg.values, vg.counts ), - UniformChoiceGenerator( + UniformChoiceProposer( table_name, column_name, vg.values, vg.counts ), - WeightedChoiceGenerator( + WeightedChoiceProposer( table_name, column_name, vg.cvs, vg.counts ), ] if vg.counts_not_suppressed: generators += [ - ZipfChoiceGenerator( + ZipfChoiceProposer( table_name, column_name, vg.values_not_suppressed, vg.counts_not_suppressed, suppress_count=self.SUPPRESS_COUNT, ), - UniformChoiceGenerator( + UniformChoiceProposer( table_name, column_name, vg.values_not_suppressed, vg.counts_not_suppressed, suppress_count=self.SUPPRESS_COUNT, ), - WeightedChoiceGenerator( + WeightedChoiceProposer( table_name=table_name, column_name=column_name, values=vg.cvs_not_suppressed, @@ -361,21 +361,21 @@ def get_proposers( vg = ValueGatherer(sampled_results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator( + ZipfChoiceProposer( table_name, column_name, vg.values, vg.counts, sample_count=self.SAMPLE_COUNT, ), - UniformChoiceGenerator( + UniformChoiceProposer( table_name, column_name, vg.values, 
vg.counts, sample_count=self.SAMPLE_COUNT, ), - WeightedChoiceGenerator( + WeightedChoiceProposer( table_name, column_name, vg.cvs, @@ -385,7 +385,7 @@ def get_proposers( ] if vg.counts_not_suppressed: generators += [ - ZipfChoiceGenerator( + ZipfChoiceProposer( table_name, column_name, vg.values_not_suppressed, @@ -393,7 +393,7 @@ def get_proposers( sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), - UniformChoiceGenerator( + UniformChoiceProposer( table_name, column_name, vg.values_not_suppressed, @@ -401,7 +401,7 @@ def get_proposers( sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), - WeightedChoiceGenerator( + WeightedChoiceProposer( table_name=table_name, column_name=column_name, values=vg.cvs_not_suppressed, diff --git a/datafaker/proposers/continuous.py b/datafaker/proposers/continuous.py index ce2352a7..88c63d16 100644 --- a/datafaker/proposers/continuous.py +++ b/datafaker/proposers/continuous.py @@ -21,13 +21,13 @@ from datafaker.utils import logger -class ContinuousDistributionGenerator(Proposer): +class ContinuousDistributionProposer(Proposer): """Base class for generators producing continuous distributions.""" expected_buckets: Sequence[NumericType] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): - """Initialise a ContinuousDistributionGenerator.""" + """Initialise a ContinuousDistributionProposer.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -77,7 +77,7 @@ def fit(self, default: float = -1) -> float: return self.buckets.fit_from_counts(self.expected_buckets) -class GaussianGenerator(ContinuousDistributionGenerator): +class GaussianProposer(ContinuousDistributionProposer): """Generator producing numbers in a Gaussian (normal) distribution.""" expected_buckets = [ @@ -105,7 +105,7 @@ def generate_data(self, count: int) -> list[Any]: ] -class UniformGenerator(ContinuousDistributionGenerator): +class 
UniformProposer(ContinuousDistributionProposer): """Generator producing numbers in a uniform distribution.""" expected_buckets = [ @@ -144,8 +144,8 @@ def _get_generators_from_buckets( buckets: Buckets, ) -> Sequence[Proposer]: return [ - GaussianGenerator(table_name, column_name, buckets), - UniformGenerator(table_name, column_name, buckets), + GaussianProposer(table_name, column_name, buckets), + UniformProposer(table_name, column_name, buckets), ] def get_proposers( @@ -168,7 +168,7 @@ def get_proposers( ) -class LogNormalGenerator(Proposer): +class LogNormalProposer(Proposer): """Generator producing numbers in a log-normal distribution.""" # R: @@ -199,7 +199,7 @@ def __init__( logmean: float, logstddev: float, ): - """Initialise a LogNormalGenerator.""" + """Initialise a LogNormalProposer.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -288,7 +288,7 @@ def _get_generators_from_buckets( if result is None or result.logstddev is None: return [] return [ - LogNormalGenerator( + LogNormalProposer( table_name, column_name, buckets, @@ -298,7 +298,7 @@ def _get_generators_from_buckets( ] -class MultivariateNormalGenerator(Proposer): +class MultivariateNormalProposer(Proposer): """Generator of multiple values drawn from a multivariate normal distribution.""" # pylint: disable=too-many-arguments too-many-positional-arguments @@ -310,7 +310,7 @@ def __init__( covariates: RowMapping, function_name: str, ) -> None: - """Initialise a MultivariateNormalGenerator.""" + """Initialise a MultivariateNormalProposer.""" self._table = table_name self._columns = column_names self._query = query @@ -639,7 +639,7 @@ def get_proposers( if not covariates or covariates["c0_0"] is None: return [] return [ - MultivariateNormalGenerator( + MultivariateNormalProposer( table, column_names, query, diff --git a/datafaker/proposers/mimesis.py b/datafaker/proposers/mimesis.py index 14c0ca4f..3fa1bdc3 100644 --- a/datafaker/proposers/mimesis.py +++ 
b/datafaker/proposers/mimesis.py @@ -27,15 +27,15 @@ generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) -class MimesisGeneratorBase(Proposer): - """Base class for a generator using Mimesis.""" +class MimesisProposerBase(Proposer): + """Base class for a proposer using Mimesis.""" def __init__( self, function_name: str, ): """ - Initialise a generator that uses Mimesis. + Initialise a proposer that uses Mimesis. :param function_name: is relative to 'generic', for example 'person.name'. """ @@ -64,7 +64,7 @@ def generate_data(self, count: int) -> list[Any]: return [self._generator_function() for _ in range(count)] -class MimesisGenerator(MimesisGeneratorBase): +class MimesisProposer(MimesisProposerBase): """A generator using Mimesis.""" def __init__( @@ -108,7 +108,7 @@ def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit -class MimesisGeneratorTruncated(MimesisGenerator): +class MimesisGeneratorTruncated(MimesisProposer): """A string generator using Mimesis that must fit within a certain number of characters.""" def __init__( @@ -151,8 +151,8 @@ def generate_data(self, count: int) -> list[Any]: return [self._generator_function()[: self._length] for _ in range(count)] -class MimesisDateTimeGenerator(MimesisGeneratorBase): - """DateTime generator using Mimesis.""" +class MimesisDateTimeProposer(MimesisProposerBase): + """DateTime proposer using Mimesis.""" # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( @@ -165,7 +165,7 @@ def __init__( end: int, ) -> None: """ - Initialise a MimesisDateTimeGenerator. + Initialise a MimesisDateTimeProposer. 
:param column: The column to generate into :param function_name: The name of the mimesis function @@ -198,7 +198,7 @@ def make_singleton( if result is None or result.start is None or result.end is None: return [] return [ - MimesisDateTimeGenerator( + MimesisDateTimeProposer( column, function_name, min_year, @@ -335,7 +335,7 @@ def get_proposers( buckets=buckets, ) return self._get_generators_with( - MimesisGenerator, + MimesisProposer, value_fn=fitness_fn, buckets=buckets, ) @@ -355,7 +355,7 @@ def get_proposers( return [] return list( map( - MimesisGenerator, + MimesisProposer, [ "person.height", ], @@ -376,7 +376,7 @@ def get_proposers( ct = get_column_type(column) if not isinstance(ct, Date): return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") + return MimesisDateTimeProposer.make_singleton(column, engine, "datetime.date") class MimesisDateTimeProposerFactory(ProposerFactory): @@ -392,7 +392,7 @@ def get_proposers( ct = get_column_type(column) if not isinstance(ct, DateTime): return [] - return MimesisDateTimeGenerator.make_singleton( + return MimesisDateTimeProposer.make_singleton( column, engine, "datetime.datetime" ) @@ -410,7 +410,7 @@ def get_proposers( ct = get_column_type(column) if not isinstance(ct, Time): return [] - return [MimesisGenerator("datetime.time")] + return [MimesisProposer("datetime.time")] class MimesisIntegerProposerFactory(ProposerFactory): @@ -426,4 +426,4 @@ def get_proposers( ct = get_column_type(column) if not isinstance(ct, Numeric) and not isinstance(ct, Integer): return [] - return [MimesisGenerator("person.weight")] + return [MimesisProposer("person.weight")] diff --git a/datafaker/proposers/partitioned.py b/datafaker/proposers/partitioned.py index 60216c3b..cbecc999 100644 --- a/datafaker/proposers/partitioned.py +++ b/datafaker/proposers/partitioned.py @@ -181,7 +181,7 @@ def __init__( ] + [f"{nc.column.name}: {nc.bitmask}" for nc in nullable_columns] -class 
NullPartitionedNormalGenerator(Proposer): +class NullPartitionedNormalProposer(Proposer): """ A generator of mixed numeric and non-numeric data. @@ -495,7 +495,7 @@ def _get_generator( columns: list[Column], nullable_columns: list[NullableColumn], name_suffix: str | None = None, - ) -> NullPartitionedNormalGenerator | None: + ) -> NullPartitionedNormalProposer | None: where = "" if 1 < cov_query.suppress_count: where = f' WHERE {cov_query.suppress_count} < "count"' @@ -523,7 +523,7 @@ def _get_generator( if not self._execute_partition_queries(connection, partitions): return None query = self.get_partition_count_query(nullable_columns, cov_query.table, where) - return NullPartitionedNormalGenerator( + return NullPartitionedNormalProposer( f"{cov_query.table}__{columns[0].name}", partitions, self.function_name(),