Skip to content

huggingface_backend

huggingface_backend

HuggingFace Trainer backend for LoRA fine-tuning.

Classes:

Name Description
HuggingFaceBackend

Training backend built on the HuggingFace Trainer.

Functions:

Name Description
preprocess_logits_for_metrics

Reduce logits to argmax predictions to avoid OOM during evaluation.

compute_metrics

Compute evaluation metrics from forward-pass losses.

HuggingFaceBackend(*args, **kwargs)

Bases: TrainingBackend

Training backend built on the HuggingFace Trainer.

Handles model loading (AutoModelForCausalLM), LoRA/QLoRA wrapping, RoPE scaling, optional differential-privacy training via OpacusDPTrainer, and artifact persistence (adapter, schema, metadata).

Methods:

Name Description
prepare_config

Set common model arguments for initializing a model.

maybe_quantize

Apply LoRA wrapping (and optional k-bit quantization) to the model.

load_model

Load an AutoModelForCausalLM instance with specified arguments.

prepare_params

Prepare training parameters and create the trainer.

prepare_training_data

Validate, preprocess, and tokenize the training dataset.

train

Run the full training pipeline and populate results.

save_model

Save the fine-tuning adapter and related artifacts to the given path.

delete_trainable_model

Delete the trainable model, trainer, and clean up GPU memory and distributed resources.

info

Print a summary of key trainer attributes to stdout.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def __init__(self, *args, **kwargs):
    """Initialize the HuggingFace backend on top of the base TrainingBackend.

    Sets the default trainer/model-loader types, resolves the training
    cache directory, and eagerly fetches the model's ``AutoConfig``.
    """
    super().__init__(*args, **kwargs)
    # Default to the vanilla HF Trainer; DP training swaps in a
    # partial(OpacusDPTrainer) elsewhere (hence the union type).
    self.trainer_type: type[Trainer] | partial[OpacusDPTrainer] = Trainer
    self.model_loader_type = AutoModelForCausalLM
    # Intermediate training artifacts (checkpoints, etc.) go to the
    # workdir's train cache.
    self.training_output_dir = Path(self.workdir.train.cache)
    # Fetch the model config up front; trust_remote_code is gated per-model.
    self.autoconfig = AutoConfig.from_pretrained(
        self.params.training.pretrained_model, trust_remote_code=self._trust_remote_code_for_model()
    )

prepare_config(add_max_memory=True, **kwargs)

Set common model arguments for initializing a model.

Parameters:

Name Type Description Default
add_max_memory bool

Whether to add max_memory to the model arguments.

True
**kwargs

Additional keyword arguments, overriding default arguments when set.

{}
Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
@traced_runtime("prepare_config")
def prepare_config(self, add_max_memory: bool = True, **kwargs):
    """Assemble the keyword arguments used to load the pretrained model.

    Populates ``self.framework_load_params`` exactly once; repeated calls
    are no-ops.

    Args:
        add_max_memory: Whether to add max_memory to the model arguments.
        **kwargs: Additional keyword arguments, overriding default arguments when set.
    """
    # Idempotent: bail out if a previous call already built the params.
    if self.framework_load_params:
        logger.info("already prepared loading parameters")
        return

    logger.info(f"preparing parameters for HF Automodel with model: {self.params.training.pretrained_model}")

    filtered_kwargs = self._filter_model_kwargs(kwargs)

    if add_max_memory:
        # max_vram_fraction is consumed here; pop it so it never reaches
        # from_pretrained() as an unexpected keyword.
        vram_fraction = filtered_kwargs.pop("max_vram_fraction", None)
        filtered_kwargs["max_memory"] = get_max_vram(max_vram_fraction=vram_fraction)

    load_params = self._build_base_framework_params(filtered_kwargs)
    quantization = self._get_quantization_config_if_enabled()
    if quantization is not None:
        load_params["quantization_config"] = quantization

    # RoPE scaling may mutate load_params in place before we commit it.
    self._apply_rope_scaling(framework_params=load_params, **kwargs)
    self.framework_load_params = load_params

maybe_quantize(**quant_params)

Apply LoRA wrapping (and optional k-bit quantization) to the model.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def maybe_quantize(self, **quant_params: dict):
    """Wrap the model with a LoRA adapter, optionally after k-bit quantization prep."""
    self._prepare_quantize_base(**quant_params)
    adapter_config = LoraConfig(**self.quant_params)

    if self.params.training.quantize_model:
        # k-bit path: PEFT's helper enables checkpointing and casts
        # norm/embedding layers appropriately.
        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=True)
    else:
        self.model.gradient_checkpointing_enable()
        # see https://discuss.huggingface.co/t/i-used-to-have-no-problem-with-peft-fine-tuning-after-hundreds-of-trainings-but-now-i-have-encountered-the-error-runtimeerror-element-0-of-tensors-does-not-require-grad-and-does-not-have-a-grad-fn/168829/3
        self.model.enable_input_require_grads()  # critical with PEFT + checkpointing
        self.model.config.use_cache = False  # cache off during training

    if not isinstance(self.model, PreTrainedModel):
        raise TypeError(f"Expected PreTrainedModel, got {type(self.model)}")
    self.model = get_peft_model_hf(self.model, peft_config=adapter_config)  # ty: ignore[invalid-assignment]  -- PeftMixedModel not in union, but LoraConfig always yields PeftModel
    parameter_count = get_model_param_count(self.model, trainable_only=True) / 1e6
    logger.info(
        f"Using PEFT - {parameter_count:.2f} million parameters are trainable",
    )

load_model(**model_args)

Load an AutoModelForCausalLM instance with specified arguments.

Parameters:

Name Type Description Default
**model_args

Additional keyword arguments for model configuration, passed directly to AutoModelForCausalLM.from_pretrained().

{}
Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def load_model(self, **model_args):
    """Load an ``AutoModelForCausalLM`` instance with specified arguments.

    Args:
        **model_args: Additional keyword arguments for model configuration,
            passed directly to ``AutoModelForCausalLM.from_pretrained()``.
    """
    logger.info(f"loading pretrained model: {self.params.training.pretrained_model}")
    # Order matters: build load params first, load the base weights,
    # then apply LoRA/quantization wrapping on top.
    self.prepare_config(**model_args)
    self._load_pretrained_model(**model_args)
    self.maybe_quantize(**model_args)

prepare_params(**training_args)

Prepare training parameters and create the trainer.

Parameters:

Name Type Description Default
**training_args

Additional training arguments (currently unused but kept for API compatibility).

{}
Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
@traced_runtime("prepare_params")
def prepare_params(self, **training_args):
    """Prepare training parameters and create the trainer.

    Builds the effective HF ``TrainingArguments`` dict, applies
    eval-dataset and DP/standard-training configuration, and creates
    the trainer with its callbacks.

    Args:
        **training_args: Additional training arguments (currently unused but kept for API compatibility).
    """
    if not hasattr(self, "model"):
        self.load_model()

    # Use a distinct local name instead of shadowing the ``training_args``
    # parameter (the previous code reassigned the parameter, silently
    # discarding caller-supplied kwargs in a confusing way).
    hf_args = self._build_base_training_args()
    self._apply_eval_dataset_overrides(hf_args)

    # DP training uses a different collator/trainer configuration.
    if self.params.privacy is not None and self.params.privacy.dp_enabled:
        data_collator = self._configure_dp_training(hf_args)
    else:
        data_collator = self._configure_standard_training(hf_args)

    # Enable W&B logging if a WANDB run is initialized
    hf_args["report_to"] = "wandb" if wandb.run is not None else "none"
    self.train_args = TrainingArguments(**hf_args)
    self.trainer = self._create_trainer(self.train_args, data_collator)
    self._configure_trainer_callbacks(self.trainer, hf_args)

prepare_training_data()

Validate, preprocess, and tokenize the training dataset.

Runs auto-config resolution, time-series processing, groupby / orderby validation, and assembles tokenized training examples. Populates training_examples, dataset_schema, df_train, and data_fraction.

Raises:

Type Description
DataError

If the training dataset is missing or malformed.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def prepare_training_data(self):
    """Validate, preprocess, and tokenize the training dataset.

    Runs auto-config resolution, time-series processing, groupby /
    orderby validation, and assembles tokenized training examples.
    Populates ``training_examples``, ``dataset_schema``,
    ``df_train``, and ``data_fraction``.

    Raises:
        DataError: If the training dataset is missing or malformed.
    """
    logger.info("Preparing training data.")

    if self.training_dataset is None:
        raise DataError("training_dataset must be set before preparing training data")

    df_all = self.training_dataset.to_pandas()
    if not isinstance(df_all, pd.DataFrame):
        raise DataError("Expected DataFrame from to_pandas(), got an iterator")

    self.params = AutoConfigResolver(df_all, self.params).resolve()

    # Validate groupby/orderby parameters as a preprocessing step.
    self._validate_groupby_column(df_all)
    self._validate_orderby_column(df_all)

    # Process time series data (sort by timestamp, infer intervals, etc.)
    df_all = self._process_timeseries(df_all)

    df_train = self._apply_preprocessing(df_all)
    df_test = None

    hf_dataset = Dataset.from_pandas(df_train, preserve_index=False)
    # Exclude PSEUDO_GROUP_COLUMN from schema (internal column for ungrouped time series)
    schema_df = df_train.drop(columns=[PSEUDO_GROUP_COLUMN], errors="ignore")
    self.dataset_schema = make_json_schema(schema_df)
    self.df_train = df_train
    self.df_test = df_test

    assembler = self._create_example_assembler(hf_dataset)

    # This is a proxy for the number of training steps.
    self.data_fraction = self.params.training.num_input_records_to_sample / assembler.num_records_train

    self._log_dataset_statistics(assembler)

    self.training_examples = assembler.assemble_training_examples(data_fraction=self.data_fraction)

    logger.user.info(
        f"Number of training examples: {len(self.training_examples.train)}",
    )

    # This info is needed inside the trainer for DP
    # Number of records, if group_training_examples_by is None, or else number of groups
    self.true_dataset_size = len(assembler.train_dataset)

    if self.params.time_series.is_timeseries:
        self.model_metadata.initial_prefill = assembler._get_initial_prefill()

train(**training_args)

Run the full training pipeline and populate results.

Sequentially calls prepare_training_data, prepare_params, trains the model, and saves artifacts.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
@utils.time_function
def train(self, **training_args):
    """Run the full training pipeline and populate ``results``.

    Sequentially calls ``prepare_training_data``,
    ``prepare_params``, trains the model, and saves artifacts.

    Args:
        **training_args: Forwarded to ``prepare_params``.
    """
    training_start = time.monotonic()
    self.prepare_training_data()
    self.prepare_params(**training_args)
    self.trainer.train()
    training_time_sec = time.monotonic() - training_start

    # Save log_history before save_model() which may delete the trainer
    log_history = self.trainer.state.log_history
    # A "training_incomplete" key in any log entry marks an aborted run.
    # any() short-circuits and avoids the previous sum(list, []) flatten,
    # which is quadratic in the number of log entries.
    is_complete = not any("training_incomplete" in entry for entry in log_history)

    self.save_model()

    self.results = NSSTrainerResult(
        df_train=self.df_train,
        df_ml_utility_holdout=self.df_test,
        config=self.params,
        training_complete=is_complete,
        log_history=log_history,
        adapter_path=self.model_metadata.adapter_path,
        elapsed_time=training_time_sec,
    )

save_model(delete_trainable_model=True)

Save the fine-tuning adapter and related artifacts to the given path.

Parameters:

Name Type Description Default
delete_trainable_model bool

If True, delete the model from memory after saving.

True
Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def save_model(self, delete_trainable_model: bool = True) -> None:
    """Persist the LoRA adapter plus metadata, schema, and config artifacts.

    Args:
        delete_trainable_model: If True, delete the model from memory after saving.
    """
    if self.dataset_schema is None:
        raise ParameterError("dataset_schema must be set before saving model")

    adapter_dir = self.workdir.train.adapter
    if not isinstance(adapter_dir, BoundDir):
        raise TypeError(f"Expected BoundDir, got {type(adapter_dir)}")
    self.workdir.ensure_directories()

    logger.user.info(f"Saving LoRA adapter to {adapter_dir}")
    # save_pretrained prints progress to stdout; capture it and demote
    # it to a debug-level log line.
    captured = io.StringIO()
    with redirect_stdout(captured):
        self.model.save_pretrained(str(adapter_dir))
    logger.runtime.debug(captured.getvalue())

    logger.user.info(f"Saving model metadata to {adapter_dir.metadata}")
    self.model_metadata.save_metadata()

    logger.user.info(f"Saving dataset schema to {adapter_dir.schema}")
    write_json(self.dataset_schema, adapter_dir.schema, indent=4)

    logger.user.info(f"Saving model parameters to {self.workdir.train.config}")
    serialized_params = self.params.model_dump(mode="json")
    write_json(serialized_params, path=self.workdir.train.config, indent=4)

    if delete_trainable_model:
        self.delete_trainable_model()

delete_trainable_model()

Delete the trainable model, trainer, and clean up GPU memory and distributed resources.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def delete_trainable_model(self) -> None:
    """Delete the trainable model, trainer, and clean up GPU memory and distributed resources."""
    import torch.distributed as dist

    # The trainer holds references to the model, so drop it first.
    for attr in ("trainer", "model"):
        if hasattr(self, attr):
            delattr(self, attr)
    cleanup_memory()

    # Tear down the distributed process group if the Trainer started one.
    if TYPE_CHECKING:
        assert hasattr(dist, "destroy_process_group")
        assert hasattr(dist, "is_initialized")
    if dist.is_initialized():
        dist.destroy_process_group()

info()

Print a summary of key trainer attributes to stdout.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def info(self):
    """Print a summary of key trainer attributes to stdout."""
    fields = [
        "params",
        "training_output_dir",
        "save_path",
        "artifact_path",
    ]
    info = {field: getattr(self, field) for field in fields}
    msg = "Trainer Information"
    msg += "\n" + "-" * len(msg)
    msg += "\n" + "\n".join([f"{field}: {value}" for field, value in info.items()])
    msg += "\n" + "-" * len(msg)

    logger.info(msg)

preprocess_logits_for_metrics(logits, labels)

Reduce logits to argmax predictions to avoid OOM during evaluation.

The default Trainer stores full logit tensors across evaluation batches, which can exhaust GPU memory on large datasets. This callback replaces them with predicted token IDs immediately after the forward pass.

See: https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/13

Parameters:

Name Type Description Default
logits tuple[Tensor, ...]

Tuple of logits tensors from the model output.

required
labels Tensor

Ground truth labels tensor.

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (predicted_token_ids, labels).

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def preprocess_logits_for_metrics(
    logits: tuple[torch.Tensor, ...], labels: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reduce logits to argmax predictions to avoid OOM during evaluation.

    The default Trainer stores full logit tensors across evaluation batches,
    which can exhaust GPU memory on large datasets. This callback replaces
    them with predicted token IDs immediately after the forward pass.

    See: https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/13

    Args:
        logits: Tuple of logits tensors from the model output.
        labels: Ground truth labels tensor.

    Returns:
        Tuple of ``(predicted_token_ids, labels)``.
    """
    # Only the first tensor in the tuple carries the token logits; collapse
    # the vocab dimension to predicted token IDs.
    predictions = logits[0].argmax(dim=-1)
    return predictions, labels

compute_metrics(eval_preds)

Compute evaluation metrics from forward-pass losses.

Metrics returned:

  • mean cross-entropy loss (eval_loss) -- average of per-batch losses collected during the evaluation loop.

The per-batch losses are pre-computed during the forward pass (via include_for_metrics).

Parameters:

Name Type Description Default
eval_preds EvalPrediction

Evaluation predictions object whose losses field contains per-batch losses collected during the eval loop.

required

Returns:

Type Description
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/nemo_safe_synthesizer/training/huggingface_backend.py
def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
    """Compute evaluation metrics from forward-pass losses.

    Metrics returned:

    - mean cross-entropy loss (``eval_loss``) -- average of per-batch
      losses collected during the evaluation loop. NaN when no losses
      were collected.

    The per-batch losses are pre-computed during the forward pass
    (via ``include_for_metrics``).

    Args:
        eval_preds: Evaluation predictions object whose ``losses`` field
            contains per-batch losses collected during the eval loop.

    Returns:
        Dictionary mapping metric names to values.
    """
    # include_for_metrics has "loss", so the loss is already computed in the forward pass
    losses = eval_preds.losses
    if losses is None or len(losses) == 0:
        # Guard explicitly: np.mean([]) emits a RuntimeWarning and yields
        # nan, and the previous `is not None` check was dead code (np.mean
        # never returns None). Skip the log line for a NaN loss.
        return {"eval_loss": float("nan")}

    # Cast to a plain float so the declared dict[str, float] return type holds
    # (np.mean returns np.float64).
    eval_loss = float(np.mean(losses))
    metrics = {"eval_loss": eval_loss}

    # Log the evaluation loss using the same style as callbacks.py
    logger.user.info(
        f"Evaluation loss: {eval_loss:.4f}",
        extra={
            "ctx": {"eval_loss": round(eval_loss, 4)},
        },
    )

    return metrics