Skip to content

utils

utils

Shared utilities for Safe Synthesizer.

Provides schema prompt creation, statistics logging, file I/O helpers, data loading, and general-purpose functions used across the pipeline.

Functions:

Name Description
create_schema_prompt

Create the schema prompt from column names and a template.

get_random_number_generator

Return a random number generator with the given seed.

log_stats

Log aggregated statistics as a structured table.

log_training_example_stats

Log training example statistics from the given dictionary.

round_number_if_float

Round the number to the given precision if it is a float.

smart_read_table

Load tabular data from a file path, or return an existing DataFrame.

time_function

Decorator to log the time taken by a function to execute.

grouped_train_test_split

Split a HuggingFace Dataset preserving group membership.

debug_fmt

Format dataframes for the purposes of data actions debugging.

merge_dicts

Deep-merge two dicts, preferring values from new on conflict.

is_iterable

Check whether x has both __iter__ and __getitem__.

flatten

Flatten a possibly nested iterable.

all_equal_type

Check whether every element in an iterable is an instance of type_.

write_json

Write a dictionary to a JSON file, creating parent directories as needed.

load_json

Load JSON file and return the content as a dict.

create_schema_prompt(columns, instruction, prompt_template, prefill='', exclude_columns=None)

Create the schema prompt from column names and a template.

Parameters:

Name Type Description Default
columns list[str]

List of column names to include in the schema.

required
instruction str

Instruction text placed before the schema.

required
prompt_template str

Template string with {instruction}, {schema}, and {prefill} placeholders.

required
prefill str

Optional text appended after the schema.

''
exclude_columns list[str] | None

Column names to omit from the schema.

None

Returns:

Type Description
str

The formatted prompt string.

Source code in src/nemo_safe_synthesizer/utils.py
def create_schema_prompt(
    columns: list[str],
    instruction: str,
    prompt_template: str,
    prefill: str = "",
    exclude_columns: list[str] | None = None,
) -> str:
    """Create the schema prompt from column names and a template.

    Args:
        columns: List of column names to include in the schema.
        instruction: Instruction text placed before the schema.
        prompt_template: Template string with ``{instruction}``, ``{schema}``,
            and ``{prefill}`` placeholders.
        prefill: Optional text appended after the schema.
        exclude_columns: Column names to omit from the schema.

    Returns:
        The formatted prompt string.
    """
    exclude_set = set(exclude_columns or [])
    return prompt_template.format(
        instruction=instruction,
        schema=",".join([f'"{c}":<unk>' for c in columns if c not in exclude_set]),
        prefill=prefill,
    )

get_random_number_generator(seed)

Return a random number generator with the given seed.

Source code in src/nemo_safe_synthesizer/utils.py
def get_random_number_generator(seed: int) -> np.random.Generator:
    """Construct a NumPy ``Generator`` seeded deterministically from ``seed``."""
    rng = np.random.default_rng(seed)
    return rng

log_stats(stats, headers=None, title=None)

Log aggregated statistics as a structured table.

Console output is rendered as a Rich ASCII table by the structlog processor; JSON logs receive structured key/value pairs.

Parameters:

Name Type Description Default
stats Statistics | list[Statistics]

One or more Statistics objects.

required
headers list[str] | None

Column headers (one per Statistics object).

None
title str | None

Optional table title.

None
Source code in src/nemo_safe_synthesizer/utils.py
def log_stats(
    stats: Statistics | list[Statistics],
    headers: list[str] | None = None,
    title: str | None = None,
) -> None:
    """Log aggregated statistics as a structured table.

    Console output is rendered as a Rich ASCII table by the structlog
    processor; JSON logs receive structured key/value pairs.

    Args:
        stats: One or more ``Statistics`` objects.
        headers: Column headers (one per ``Statistics`` object). Missing
            headers are auto-generated so no statistic is silently dropped.
        title: Optional table title.
    """
    headers = list(headers or [])
    stats = stats if isinstance(stats, list) else [stats]

    # BUG FIX: zip() truncates to the shorter sequence, so any Statistics
    # without a matching header (e.g. with the default headers=None) was
    # silently dropped from the log. Pad headers with generated names.
    headers += [f"column_{i + 1}" for i in range(len(headers), len(stats))]

    # Build structured data - processor will render as table for console
    structured_stats = {}
    for header, stat in zip(headers, stats):
        key = header.lower().replace(" ", "_")
        structured_stats[key] = {
            "min": round_number_if_float(stat.min),
            "max": round_number_if_float(stat.max),
            "mean": round_number_if_float(stat.mean),
            "stddev": round_number_if_float(stat.stddev),
        }

    stats_title = title or "Statistics"

    # Pass structured data - processor renders for console, JSON keeps as-is
    logger.info(
        "",
        extra={
            "ctx": {
                "render_table": True,
                "tabular_data": structured_stats,
                "title": stats_title,
            }
        },
    )

log_training_example_stats(stats_dict, **kwargs)

Log training example statistics from the given dictionary.

Source code in src/nemo_safe_synthesizer/utils.py
def log_training_example_stats(stats_dict: dict[str, Statistics], **kwargs) -> None:
    """Log training example statistics from the given dictionary.

    Args:
        stats_dict: Mapping of snake_case statistic names to ``Statistics``.
        **kwargs: Extra keyword arguments forwarded to ``log_stats``.
    """
    stats = list(stats_dict.values())
    # Human-readable headers, e.g. "num_tokens" -> "Num tokens".
    # (Removed the redundant list(...) wrapper around the comprehension.)
    headers = [name.replace("_", " ").capitalize() for name in stats_dict]
    log_stats(title="Training Example Statistics", stats=stats, headers=headers, **kwargs)

round_number_if_float(number, precision=3)

Round the number to the given precision if it is a float.

Source code in src/nemo_safe_synthesizer/utils.py
def round_number_if_float(number, precision=3):
    """Round ``number`` to ``precision`` decimal places when it is a float.

    Non-float values (ints, strings, None, ...) pass through untouched.
    """
    if isinstance(number, float):
        return round(number, precision)
    return number

smart_read_table(df_or_path)

Load tabular data from a file path, or return an existing DataFrame.

Supported formats: CSV, JSON, JSONL, and Parquet.

Parameters:

Name Type Description Default
df_or_path str | Path | DataFrame

A DataFrame (returned as-is), or a path to a .csv, .json, .jsonl, or .parquet file.

required

Returns:

Type Description
DataFrame

The loaded (or passed-through) DataFrame.

Raises:

Type Description
ValueError

If the file extension is not supported.

Source code in src/nemo_safe_synthesizer/utils.py
def smart_read_table(df_or_path: str | Path | pd.DataFrame) -> pd.DataFrame:
    """Load tabular data from a file path, or return an existing DataFrame.

    Supported formats: CSV, JSON, JSONL, and Parquet.

    Args:
        df_or_path: A ``DataFrame`` (returned as-is), or a path to a
            ``.csv``, ``.json``, ``.jsonl``, or ``.parquet`` file.

    Returns:
        The loaded (or passed-through) ``DataFrame``.

    Raises:
        ValueError: If the file extension is not supported.
    """
    if isinstance(df_or_path, pd.DataFrame):
        return df_or_path

    path = str(df_or_path)
    if path.endswith(".csv"):
        df = pd.read_csv(path)
    elif path.endswith(".json"):
        try:
            df = pd.read_json(path)
        except Exception:
            df = pd.read_json(path, lines=True, orient="records")
    elif path.endswith(".jsonl"):
        df = pd.read_json(path, lines=True, orient="records")
    elif path.endswith(".parquet"):
        df = pd.read_parquet(path)
    else:
        raise ValueError(f"Unsupported file type: {path}")
    return df

time_function(func)

Decorator to log the time taken by a function to execute.

Source code in src/nemo_safe_synthesizer/utils.py
def time_function(func):
    """Decorator that logs how long each call to ``func`` takes."""

    @functools.wraps(func)
    def timed(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        # Report in minutes once the call exceeds two minutes.
        if elapsed > 120:
            duration = f"{elapsed / 60:.2f} min"
        else:
            duration = f"{elapsed:.2f} sec"
        logger.info(f"⏱️  Function: {func.__name__}, Time: {duration}\n")
        return result

    return timed

grouped_train_test_split(dataset, test_size, group_by, seed=None)

Split a HuggingFace Dataset preserving group membership.

Currently unused. Converts the dataset to a pandas DataFrame and delegates to holdout.grouped_train_test_split.

Parameters:

Name Type Description Default
dataset Dataset

The HuggingFace Dataset to split.

required
test_size float

Fraction or absolute number of test rows.

required
group_by str | list[str]

Column name or list of column names defining groups.

required
seed int | None

Random state for reproducibility.

None

Returns:

Type Description
tuple[DataFrame, DataFrame | None]

Tuple of (train_df, test_df), or (train_df, None) on failure.

Source code in src/nemo_safe_synthesizer/utils.py
def grouped_train_test_split(
    dataset: Dataset,
    test_size: float,
    group_by: str | list[str],
    seed: int | None = None,
) -> tuple[DataFrame, DataFrame | None]:
    """Split a HuggingFace Dataset preserving group membership.

    Currently unused. Converts the dataset to a pandas ``DataFrame`` and
    delegates to ``holdout.grouped_train_test_split``.

    Args:
        dataset: The HuggingFace ``Dataset`` to split.
        test_size: Fraction or absolute number of test rows.
        group_by: Column name or list of column names defining groups.
        seed: Random state for reproducibility.

    Returns:
        Tuple of ``(train_df, test_df)``, or ``(train_df, None)`` on
        failure.
    """
    # Lazy import avoids pulling the holdout dependency in for SDK-side tests.
    from .holdout import holdout as nss_holdout

    # The group-aware splitter operates on pandas, so convert first.
    frame = dataset.to_pandas()
    return nss_holdout.grouped_train_test_split(
        df=frame,
        test_size=test_size,
        group_by=group_by,
        random_state=seed,
    )

debug_fmt(df)

Format dataframes for the purposes of data actions debugging.

Source code in src/nemo_safe_synthesizer/utils.py
def debug_fmt(df: pd.DataFrame) -> str:
    """Render the first five rows of ``df`` as a JSON records string for debug output."""
    preview = df.head(5)
    return preview.to_json(orient="records", date_format="iso")

merge_dicts(base, new)

Deep-merge two dicts, preferring values from new on conflict.

Source code in src/nemo_safe_synthesizer/utils.py
def merge_dicts(base: dict, new: dict) -> dict:
    """Deep-merge two dicts, preferring values from ``new`` on conflict."""
    merged = dict(base)
    for key, value in new.items():
        existing = merged.get(key)
        # Recurse only when both sides hold dicts; otherwise ``new`` wins outright.
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged

is_iterable(x)

Check whether x has both __iter__ and __getitem__.

Source code in src/nemo_safe_synthesizer/utils.py
def is_iterable(x: Any) -> bool:
    """Check whether ``x`` has both ``__iter__`` and ``__getitem__``.

    Note: this is stricter than the general iterable protocol — generators
    define ``__iter__`` but not ``__getitem__`` and therefore return False.

    Args:
        x: Object to inspect.

    Returns:
        True if ``x`` exposes both dunder methods, False otherwise.
    """
    # Added the missing ``-> bool`` annotation for consistency with
    # ``all_equal_type(...) -> bool``.
    return hasattr(x, "__iter__") and hasattr(x, "__getitem__")

flatten(iter)

Flatten a possibly nested iterable.

Strings are yielded as-is (not broken into characters). Dicts are yielded whole with a warning since flattening them is not meaningful.

Source code in src/nemo_safe_synthesizer/utils.py
def flatten(iter) -> Generator:
    """Flatten a possibly nested iterable.

    Strings are yielded as-is (not broken into characters). Dicts are
    yielded whole with a warning since flattening them is not meaningful.
    """
    if isinstance(iter, dict):
        logger.warning("Flattening a dictionary is not supported. Returning the dictionary as is.")
        yield iter
        return
    for element in iter:
        # Strings are iterable but must not be exploded into characters.
        if isinstance(element, str) or not is_iterable(element):
            yield element
        else:
            yield from flatten(element)

all_equal_type(iter, type_, flatten_iter=True)

Check whether every element in an iterable is an instance of type_.

Parameters:

Name Type Description Default
iter

The iterable to check.

required
type_

The type to check against.

required
flatten_iter

If True, flatten nested iterables before checking.

True
Source code in src/nemo_safe_synthesizer/utils.py
def all_equal_type(iter, type_, flatten_iter=True) -> bool:
    """Check whether every element in an iterable is an instance of ``type_``.

    Args:
        iter: The iterable to check.
        type_: The type to check against.
        flatten_iter: If ``True``, flatten nested iterables before checking.

    Returns:
        True if all (possibly flattened) elements are instances of
        ``type_``; vacuously True for an empty iterable.
    """
    # Idiom fix: the previous map/for/early-return loop re-implemented the
    # builtin all(), which short-circuits on the first non-matching element.
    elements = flatten(iter) if flatten_iter else iter
    return all(isinstance(x, type_) for x in elements)

write_json(data, path, encoding=None, indent=None)

Write a dictionary to a JSON file, creating parent directories as needed.

Source code in src/nemo_safe_synthesizer/utils.py
def write_json(
    data: dict,
    path: str | os.PathLike[str],
    encoding: str | None = None,
    indent: int | None = None,
) -> None:
    """Write a dictionary to a JSON file, creating parent directories as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding=encoding) as file:
        json.dump(data, file, indent=indent)

load_json(path, encoding=None)

Load JSON file and return the content as a dict.

Source code in src/nemo_safe_synthesizer/utils.py
def load_json(path: str | Path, encoding: str | None = None) -> dict:
    """Load JSON file and return the content as a dict."""
    with Path(path).open(encoding=encoding) as file:
        return json.load(file)