callbacks

HuggingFace Trainer callbacks for Safe Synthesizer training.

Classes:

InferenceEvalCallback
    Trainer callback that performs inference-based evaluation during training.

ProgressBarCallback
    A TrainerCallback that displays the progress of training or evaluation.

SafeSynthesizerWorkerCallback
    Trainer callback that emits structured progress logs at a fixed interval.

InferenceEvalCallback(schema, metadata, processor, num_prompts_per_batch=16, num_batches=None, patience=3, invalid_fraction_threshold=0.8, generate_kwargs=None)

Bases: TrainerCallback

Trainer callback that performs inference-based evaluation during training.

Generates records using the current model and validates them against a schema. Empirically, the fraction of invalid records generated is a good indicator of model quality. The callback can stop training early once the invalid fraction meets the stopping criterion defined by invalid_fraction_threshold and patience.
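As a rough illustration of the invalid-fraction signal, a toy validity check might look like the following. This is a hypothetical sketch with illustrative names; the library's Processor performs real schema validation rather than a simple key comparison.

```python
def invalid_fraction(records: list[dict], schema: dict) -> float:
    """Toy invalid-fraction metric: a record counts as 'valid' here only if
    it supplies exactly the keys named in the schema's properties."""
    required = set(schema.get("properties", {}))
    if not records:
        return 1.0  # no records at all counts as fully invalid
    invalid = sum(1 for record in records if set(record) != required)
    return invalid / len(records)
```

A high value from a metric like this, sustained across evaluations, is the signal the callback uses to halt training.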

Parameters:

schema (dict, required)
    Schema to validate the generated records against.

metadata (ModelMetadata, required)
    Pretrained model metadata (prompt template, instruction, etc.).

processor (Processor, required)
    Record processor used to parse and validate generated text.

num_prompts_per_batch (int, default 16)
    Number of prompts per batch.

num_batches (Optional[int], default None)
    Number of batches to generate. If None, a processor-specific default is used.

invalid_fraction_threshold (float, default 0.8)
    Fraction of invalid records at or above which the stopping criterion is met.

patience (int, default 3)
    Number of consecutive evaluations in which invalid_fraction_threshold must be reached before training stops.

generate_kwargs (dict | None, default None)
    Keyword arguments to pass to the model's generate method.
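The early-stopping rule described by invalid_fraction_threshold and patience can be sketched in a few lines. This is an illustrative model of the documented behavior, not the library's GenerationBatches implementation (which tracks a running average internally); the function name and signature are hypothetical.

```python
def should_stop_training(
    invalid_fractions: list[float],
    threshold: float = 0.8,
    patience: int = 3,
) -> bool:
    """True once the last `patience` evaluations all reached `threshold`."""
    if len(invalid_fractions) < patience:
        return False
    return all(frac >= threshold for frac in invalid_fractions[-patience:])
```

With the defaults, training stops only after three consecutive evaluations each produce at least 80% invalid records; a single bad evaluation is tolerated.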

Methods:

on_evaluate
    Generate records with the current model and optionally stop training.

Source code in src/nemo_safe_synthesizer/training/callbacks.py
def __init__(
    self,
    schema: dict,
    metadata: ModelMetadata,
    processor: Processor,
    num_prompts_per_batch: int = 16,
    num_batches: Optional[int] = None,
    patience: int = 3,
    invalid_fraction_threshold: float = 0.8,
    generate_kwargs: dict | None = None,
):
    self.schema = schema
    self.metadata = metadata
    self.templated_prompt = create_schema_prompt(
        schema["properties"].keys(),
        instruction=self.metadata.instruction,
        prompt_template=self.metadata.prompt_config.template,
    )
    self.num_prompts_per_batch = num_prompts_per_batch

    self.is_tabular_processor = isinstance(processor, TabularDataProcessor)
    self.num_batches = num_batches or (
        NUM_EVAL_BATCHES_TABULAR if self.is_tabular_processor else NUM_EVAL_BATCHES_GROUPED
    )

    self.generation = GenerationBatches(
        invalid_fraction_threshold=invalid_fraction_threshold,
        patience=patience,
    )
    self.processor = processor

    kws = generate_kwargs or {}
    self.generate_kwargs = {
        "temperature": kws.get("temperature", DEFAULT_SAMPLING_PARAMETERS["temperature"]),
        "top_p": kws.get("top_p", DEFAULT_SAMPLING_PARAMETERS["top_p"]),
        "top_k": kws.get("top_k", DEFAULT_SAMPLING_PARAMETERS["top_k"]),
        "repetition_penalty": kws.get("repetition_penalty", DEFAULT_SAMPLING_PARAMETERS["repetition_penalty"]),
    }

on_evaluate(args, state, control, **kwargs)

Generate records with the current model and optionally stop training.

Runs inference for num_batches batches, validates each against the schema, and sets control.should_training_stop if the invalid record fraction exceeds the threshold for patience consecutive evaluations.

Source code in src/nemo_safe_synthesizer/training/callbacks.py
def on_evaluate(
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    **kwargs,
) -> None:
    """Generate records with the current model and optionally stop training.

    Runs inference for ``num_batches`` batches, validates each against
    the schema, and sets ``control.should_training_stop`` if the invalid
    record fraction exceeds the threshold for ``patience`` consecutive
    evaluations.
    """
    if not state.is_world_process_zero:
        return

    model = kwargs["model"]
    tokenizer = kwargs["tokenizer"]

    with optimize_for_inference(model):
        was_stopped = False

        logger.info(
            f"🔮 Starting inference-based evaluation with the '{self.processor.name}'",
        )

        for _ in range(self.num_batches):
            prompt_tokens = tokenizer(
                [self.templated_prompt] * self.num_prompts_per_batch,
                return_tensors="pt",
            )
            input_ids = prompt_tokens["input_ids"].to(model.device)
            attention_mask = prompt_tokens["attention_mask"].to(model.device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=tokenizer.model_max_length - len(input_ids[0]),
                do_sample=True,
                use_cache=True,
                **self.generate_kwargs,
            )
            decoded = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=self.is_tabular_processor,
            )

            start_time = time.perf_counter()
            batch = Batch(processor=self.processor)
            for idx, text in enumerate(decoded):
                batch.process(idx, text)
            duration = time.perf_counter() - start_time
            self.generation.add_batch(batch)

            batch.log_summary()
            duration_string = f"{duration:.1f} seconds" if duration < 120 else f"{duration / 60:.1f} minutes"
            logger.info(f"Generation time: {duration_string}")

            if self.generation.status != GenerationStatus.IN_PROGRESS:
                was_stopped = True
                break

        if was_stopped:
            control.should_training_stop = True
            if self.generation.status == GenerationStatus.STOP_NO_RECORDS:
                logger.error(
                    "🛑 Stopping generation prematurely. No records were generated. "
                    "Please consider adjusting the sampling parameters.",
                )
                state.log_history.append({"training_incomplete": "no_records"})
            elif self.generation.status == GenerationStatus.STOP_METRIC_REACHED:
                logger.error(
                    "🛑 Stopping generation prematurely. The stopping "
                    "condition was reached with a running average invalid "
                    f"fraction of {self.generation.stop_condition.last_value:.2%}",
                )
                state.log_history.append({"training_incomplete": "stopping_condition_reached"})

ProgressBarCallback()

Bases: TrainerCallback

A TrainerCallback that displays the progress of training or evaluation.

Note

This callback is intended for use during development only.

Source code in src/nemo_safe_synthesizer/training/callbacks.py
def __init__(self):
    self.training_bar = None
    self.prediction_bar = None

SafeSynthesizerWorkerCallback(log_interval=60.0)

Bases: TrainerCallback

Trainer callback that emits structured progress logs at a fixed interval.

Logs are written via the logger.runtime channel as rendered tables containing epoch, step, loss, and progress fraction.

Parameters:

log_interval (float, default 60.0)
    Minimum seconds between successive log emissions.
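The interval-based throttling can be sketched as follows. This is a minimal stand-in assuming a monotonic clock, not the actual SafeSynthesizerWorkerCallback code; the injectable clock parameter is an assumption added here to make the behavior easy to test.

```python
import time


class IntervalThrottle:
    """Emit at most one log per `log_interval` seconds of monotonic time."""

    def __init__(self, log_interval: float = 60.0, clock=time.monotonic):
        self._log_interval = log_interval
        self._clock = clock
        self._last_log_ts = clock()

    def should_log(self) -> bool:
        """True if at least `log_interval` seconds have elapsed since the
        last emission; resets the timer when it fires."""
        now = self._clock()
        if now - self._last_log_ts >= self._log_interval:
            self._last_log_ts = now
            return True
        return False
```

Using time.monotonic rather than time.time keeps the throttle immune to wall-clock adjustments during long training runs.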
Source code in src/nemo_safe_synthesizer/training/callbacks.py
def __init__(self, log_interval: float = 60.0):
    self._log_interval = log_interval
    self._last_log_ts = time.monotonic()
    self._last_log_global_step = 0