batch

Single-batch container for generated records and error statistics.

Classes:
    Batch: Container for the results of a single generation batch.

Batch(processor)

Container for the results of a single generation batch.

Collects ParsedResponse objects produced by the processor and exposes aggregate counts and error statistics.

Parameters:
    processor (Processor): The processor used to parse LLM outputs into records. Required.

Methods:
    error_statistics: Return count statistics on errors encountered during generation.
    to_dataframe: Return the valid records as a normalized DataFrame.
    log_summary: Log a summary of the batch generation results.
    process: Process text response from a single prompt in the current batch.

Attributes:
    num_prompts (int): Total number of prompts submitted in this batch.
    num_invalid_records (int): Number of invalid records generated in this batch.
    num_valid_records (int): Number of valid records generated in this batch.
    data_config_rejected_records (list[tuple[str, str]]): Error tuples for records rejected by data_config validation.
    num_data_config_rejected_records (int): Count of records rejected by data_config validation.
    valid_record_fraction (float): Fraction of generated records that passed validation.
    stopping_metric (float): Invalid record fraction, used by GenerationStopCondition.

Source code in src/nemo_safe_synthesizer/generation/batch.py
def __init__(self, processor: Processor):
    self._responses: list[ParsedResponse] = []
    self._processor = processor
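The constructor only stores the processor and an empty response list; records accumulate as `process` is called and the properties aggregate over them. A minimal sketch of that lifecycle, using stub stand-ins for `Processor` and `ParsedResponse` (both hypothetical, not the real package types):

```python
from dataclasses import dataclass, field

# Stand-ins for the package's ParsedResponse and Processor; these are
# illustrative assumptions, not the real nemo_safe_synthesizer API.
@dataclass
class StubParsedResponse:
    valid_records: list = field(default_factory=list)
    errors: list = field(default_factory=list)

def stub_processor(prompt_number: int, text: str) -> StubParsedResponse:
    # Treat every non-empty line of model output as one valid record.
    records = [{"prompt": prompt_number, "text": line} for line in text.splitlines() if line]
    return StubParsedResponse(valid_records=records)

# Mimic the Batch lifecycle: process each prompt's text, then aggregate.
responses = []
for i, text in enumerate(["alpha\nbeta", "gamma"]):
    responses.append(stub_processor(i, text))

num_valid_records = sum(len(r.valid_records) for r in responses)
print(num_valid_records)  # 3
```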

num_prompts property

Total number of prompts submitted in this batch.

num_invalid_records property

Number of invalid records generated in this batch.

num_valid_records property

Number of valid records generated in this batch.

data_config_rejected_records property

Error tuples for records rejected by data_config validation.

num_data_config_rejected_records property

Count of records rejected by data_config validation.

valid_record_fraction property

Fraction of generated records that passed validation.

stopping_metric property

Invalid record fraction, used by GenerationStopCondition.
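The attribute descriptions imply that `stopping_metric` is the complement of `valid_record_fraction`. Assuming that relationship holds (the exact definition lives in the source), the arithmetic looks like:

```python
# Assumed relationship based on the attribute docs: the stopping metric is
# the invalid-record fraction, i.e. the complement of valid_record_fraction.
num_valid_records = 18
num_invalid_records = 2
total = num_valid_records + num_invalid_records

valid_record_fraction = num_valid_records / total  # 0.9
stopping_metric = num_invalid_records / total      # 0.1
print(valid_record_fraction, stopping_metric)
```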

error_statistics(detailed_errors)

Return count statistics on errors encountered during generation.

Parameters:
    detailed_errors (bool): If True, include expected column names and allowed field values. If False, report only high-level error categories. Required.

Returns:
    DataFrame: DataFrame indexed by error message with a Percentage column, sorted by frequency descending.

Source code in src/nemo_safe_synthesizer/generation/batch.py
def error_statistics(self, detailed_errors: bool) -> pd.DataFrame:
    """Return count statistics on errors encountered during generation.

    Args:
        detailed_errors: If ``True``, include expected column names
            and allowed field values.  If ``False``, report only
            high-level error categories.

    Returns:
        DataFrame indexed by error message with a ``Percentage``
        column, sorted by frequency descending.
    """
    idx = 0 if detailed_errors else 1
    err_msgs = [e[idx] for resp in self._responses for e in resp.errors]
    # Map error messages to human-readable categories, as necessary
    err_msgs = [HUMAN_READABLE_ERR_MSGS.get(msg, msg) for msg in err_msgs]
    if detailed_errors:
        # Group similar error messages to consolidate the error counts
        common_error_strings = [
            " is not one of ",
            " is a required property",
            " is greater than the maximum of ",
            " is less than the minimum of ",
        ]
        for common_error_string in common_error_strings:
            err_msgs = _group_error_messages(err_msgs, common_error_string)
    err_stats = pd.DataFrame.from_dict(Counter(err_msgs), orient="index", columns=["cnt"])
    err_stats["Percentage"] = err_stats["cnt"] / err_stats["cnt"].sum()
    err_stats = err_stats.drop("cnt", axis=1)
    # sort `index` (error messages) then `Percentage`, so it's deterministic
    err_stats = err_stats.sort_index()
    err_stats = err_stats.sort_values("Percentage", ascending=False, kind="mergesort")
    if LOG_NUM_ERRORS is not None:
        # separate out data_config errors to ensure they don't get truncated

        data_config_errors = err_stats[err_stats.index.astype("str").str.contains("data_config", na=False)]
        other_errors = err_stats[~err_stats.index.astype("str").str.contains("data_config", na=False)]
        other_errors = other_errors.head(LOG_NUM_ERRORS)

        err_stats = pd.concat([data_config_errors, other_errors])
        err_stats = err_stats.sort_values("Percentage", ascending=False)
    return err_stats
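The core tally can be reproduced in isolation: count messages with `collections.Counter`, convert to a DataFrame, and normalize to a `Percentage` column, using the same deterministic two-pass sort as the source (the error strings here are made up for illustration):

```python
from collections import Counter

import pandas as pd

# Same counting pattern as error_statistics: tally messages, then express
# each count as a fraction of the total in a Percentage column.
err_msgs = ["missing column", "missing column", "bad value"]
err_stats = pd.DataFrame.from_dict(Counter(err_msgs), orient="index", columns=["cnt"])
err_stats["Percentage"] = err_stats["cnt"] / err_stats["cnt"].sum()
err_stats = err_stats.drop("cnt", axis=1)
# Sort by index first, then stably by Percentage, so ties are deterministic.
err_stats = err_stats.sort_index()
err_stats = err_stats.sort_values("Percentage", ascending=False, kind="mergesort")
print(err_stats.index.tolist())  # ['missing column', 'bad value']
```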

to_dataframe()

Return the valid records as a normalized DataFrame.

Returns:
    DataFrame | None: DataFrame of valid records, or None if no valid records were generated.

Source code in src/nemo_safe_synthesizer/generation/batch.py
def to_dataframe(self) -> pd.DataFrame | None:
    """Return the valid records as a normalized DataFrame.

    Returns:
        DataFrame of valid records, or ``None`` if no valid records
        were generated.
    """
    valid = [resp.valid_records for resp in self._responses]
    flat_records = [record for records in valid for record in records]
    df = pd.DataFrame.from_records(flat_records)
    return None if df.empty else normalize_dataframe(df)
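The flattening step is a plain nested list comprehension plus pandas; a self-contained sketch with made-up records:

```python
import pandas as pd

# Each response contributes a list of record dicts; flatten them before
# building a single DataFrame, as to_dataframe does.
valid = [
    [{"name": "Ada", "age": 36}],
    [{"name": "Grace", "age": 45}, {"name": "Alan", "age": 41}],
]
flat_records = [record for records in valid for record in records]
df = pd.DataFrame.from_records(flat_records)
print(df.shape)  # (3, 2)
```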

log_summary(detailed_errors=False)

Log a summary of the batch generation results.

Emits structured data via logger.user.info that is rendered as Rich ASCII tables on the console and as key/value pairs in JSON logs.

Parameters:
    detailed_errors (bool): If True, include per-column error statistics in the log output. Default: False.
Source code in src/nemo_safe_synthesizer/generation/batch.py
def log_summary(self, detailed_errors: bool = False) -> None:
    """Log a summary of the batch generation results.

    Emits structured data via ``logger.user.info`` that is rendered
    as Rich ASCII tables on the console and as key/value pairs in
    JSON logs.

    Args:
        detailed_errors: If ``True``, include per-column error
            statistics in the log output.
    """
    err_stats: pd.DataFrame = self.error_statistics(detailed_errors=detailed_errors)

    # Build structured summary data - processor renders as table for console
    summary_data = {
        "num_prompts": self.num_prompts,
        "num_valid_records": self.num_valid_records,
        "num_invalid_records": self.num_invalid_records,
        "valid_record_fraction": round(self.valid_record_fraction, 2),
    }
    if self.num_data_config_rejected_records:
        summary_data["num_data_config_rejected_records"] = self.num_data_config_rejected_records

    # Pass structured data - processor renders for console, JSON keeps as-is
    logger.user.info(
        "",
        extra={
            "ctx": {
                "render_table": True,
                "tabular_data": summary_data,
                "title": "Batch Generation Summary",
            }
        },
    )

    # Log error statistics if present
    if not err_stats.empty:
        # Build structured error data - processor renders as table for console
        error_data = {str(cat): round(row["Percentage"], 2) for cat, row in err_stats.iterrows()}
        logger.user.info(
            "",
            extra={
                "ctx": {
                    "render_table": True,
                    "tabular_data": error_data,
                    "title": "Error Statistics",
                }
            },
        )
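The `logger.user.info` call above goes through the package's own logging setup, but the `extra={"ctx": ...}` mechanism it relies on is standard library behavior: keys in `extra` become attributes on the `LogRecord`. A sketch with stdlib `logging` and a capturing handler (the handler and logger names are invented for the demo):

```python
import logging

# Capture the structured ``ctx`` dict attached via ``extra``; a real
# console renderer or JSON formatter would consume it instead.
captured = []

class CaptureHandler(logging.Handler):
    def emit(self, record):
        captured.append(getattr(record, "ctx", None))

demo_logger = logging.getLogger("batch_summary_demo")
demo_logger.addHandler(CaptureHandler())
demo_logger.setLevel(logging.INFO)

summary_data = {"num_prompts": 4, "num_valid_records": 3, "valid_record_fraction": 0.75}
demo_logger.info(
    "",
    extra={
        "ctx": {
            "render_table": True,
            "tabular_data": summary_data,
            "title": "Batch Generation Summary",
        }
    },
)
print(captured[0]["title"])  # Batch Generation Summary
```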

process(prompt_number, text)

Process text response from a single prompt in the current batch.

Parameters:
    prompt_number (int): The prompt number in the current batch. Required.
    text (str): Text generated by the fine-tuned model. Required.
Source code in src/nemo_safe_synthesizer/generation/batch.py
def process(self, prompt_number: int, text: str) -> None:
    """Process text response from a single prompt in the current batch.

    Args:
        prompt_number: The prompt number in the current batch.
        text: Text generated by the fine-tuned model.
    """
    self._responses.append(self._processor(prompt_number, text))