environment

Environment-stage checks: GPU, VRAM, tokens, log settings.

Classes:

| Name | Description |
| --- | --- |
| CUDAAvailabilityCheck | Validate CUDA GPU availability. |
| VRAMComponentEstimate | Per-device training VRAM components for gpu.vram preflight. |
| VRAMHeadroomCheck | Estimate whether GPU VRAM is sufficient for training. |
| InferenceModelCheck | Validate the inference configuration used for PII column classification. |
| HFModelAvailabilityCheck | Validate local model, HF cache, and online HF access readiness. |

Functions:

| Name | Description |
| --- | --- |
| param_count_from_empty_model | Count parameters by instantiating the model on the meta device. |
| estimate_params_from_shape | Shape-only fallback param count used when meta-tensor construction fails. |
| estimate_base_model_params | Return (n_params, method) for the base model, or None if unknown. |
| bytes_per_base_weight | Return expected bytes/param for the base model load mode. |
| estimate_training_vram_components | Compose base weights, overhead, and optional activation estimate (GiB). |

CUDAAvailabilityCheck

Bases: ConfigCheck

Validate CUDA GPU availability.

VRAMComponentEstimate(base_weights_gib, overhead_gib, activation_gib, total_gib) dataclass

Per-device training VRAM components for gpu.vram preflight.

VRAMHeadroomCheck

Bases: MetadataCheck

Estimate whether GPU VRAM is sufficient for training.

The estimate is a deliberately conservative heuristic, not a worst-case-accurate bound.

Parameter counts come from estimate_base_model_params via meta tensors when possible.

Activation memory is estimated through estimate_training_vram_components when metadata.max_seq_length and the transformer shape fields resolve to positive integers; if any input is missing, the activation term is left unset and the estimate falls back to a legacy lumped overhead. The per-device estimate is compared against the headroom reported by get_max_vram(max_vram_fraction=training.max_vram_fraction).

LoRA adapters, the full optimizer footprint, \(O(B S^2)\) attention buffers, and quantization workspace are only partially covered by the residual overhead, so passing does not guarantee a fit, while failing is a strong signal of OOM risk.

InferenceModelCheck

Bases: ConfigCheck

Validate the inference configuration used for PII column classification.

When classification is enabled, the runtime calls an OpenAI-compatible inference endpoint configured by NSS_INFERENCE_KEY, NSS_INFERENCE_MODEL, and NSS_INFERENCE_ENDPOINT (set directly or via the matching CLI flags, which are propagated to the environment before preflight runs). This check reads those env vars -- not config -- because the inference settings live in CLISettings/the environment rather than in SafeSynthesizerParameters.

The body uses a single-dispatch match over (model, key, endpoint), so at most one finding is emitted per run -- the highest-priority problem. Priority order: invalid endpoint, then missing key, then blank model id. The invalid endpoint is an error (a non-http(s) endpoint cannot succeed, so the run must not pass --validate); the key and model findings are warnings (classification degrades or falls back rather than failing the run). The error is checked first so a lower-severity warning never masks it.

HFModelAvailabilityCheck

Bases: ConfigCheck

Validate local model, HF cache, and online HF access readiness.

param_count_from_empty_model(autoconfig)

Count parameters by instantiating the model on the meta device.

accelerate.init_empty_weights constructs the full nn.Module graph with every parameter on torch.device("meta") -- no storage is allocated and no weights are downloaded. AutoModelForCausalLM.from_config consults the transformers model-class registry to pick the right architecture (handling Nemotron's non-gated MLP, MoE experts, biases, tied embeddings, and any future variant automatically).

Returns None if accelerate/transformers are missing, the config doesn't map to a registered architecture (e.g. trust_remote_code custom archs), or instantiation fails for any other reason. The caller should fall back to estimate_params_from_shape.

Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def param_count_from_empty_model(autoconfig: PretrainedConfig) -> int | None:
    """Count parameters by instantiating the model on the ``meta`` device.

    ``accelerate.init_empty_weights`` constructs the full ``nn.Module`` graph
    with every parameter on ``torch.device("meta")`` -- no storage is
    allocated and no weights are downloaded. ``AutoModelForCausalLM.from_config``
    consults the transformers model-class registry to pick the right
    architecture (handling Nemotron's non-gated MLP, MoE experts, biases,
    tied embeddings, and any future variant automatically).

    Returns ``None`` if accelerate/transformers are missing, the config
    doesn't map to a registered architecture (e.g. ``trust_remote_code``
    custom archs), or instantiation fails for any other reason. The caller
    should fall back to
    [estimate_params_from_shape][nemo_safe_synthesizer.preflight.checks.environment.estimate_params_from_shape].

    References:
        - HuggingFace accelerate, "Big Model Inference" --
          <https://huggingface.co/docs/accelerate/concept_guides/big_model_inference>
        - HuggingFace accelerate, "Model memory estimator" -- same
          meta-device technique exposed as ``accelerate estimate-memory``;
          reported accurate to within a few percent of real CUDA load.
          <https://huggingface.co/docs/accelerate/usage_guides/model_size_estimator>
        - PyTorch meta device --
          <https://docs.pytorch.org/docs/stable/meta.html>
    """
    try:
        from accelerate import init_empty_weights
        from transformers import AutoModelForCausalLM
    except ImportError:
        return None
    try:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(autoconfig)
        return sum(p.numel() for p in model.parameters())
    except Exception as exc:
        logger.runtime.debug(
            "param_count_from_empty_model failed for %r: %s: %s",
            getattr(autoconfig, "_name_or_path", None) or getattr(autoconfig, "model_type", "unknown"),
            type(exc).__name__,
            exc,
        )
        return None

estimate_params_from_shape(autoconfig)

Shape-only fallback param count used when meta-tensor construction fails.

Models a decoder-only transformer with grouped-query attention (which degrades to multi-head when num_key_value_heads == num_attention_heads) and a gated SwiGLU/GeGLU MLP -- the shape NSS sees on its supported model families (Llama, Qwen, Mistral, SmolLM, Granite, TinyLlama). For non-gated variants (e.g. Nemotron's squared-ReLU MLP) this over-counts MLP params by 50%, which is why the meta-tensor path is preferred.

With hidden size \(H\), intermediate size \(I\), \(L\) layers, vocabulary \(V\), \(n_\text{kv}\) KV heads and per-head dim \(d\), the per-layer cost is

\[ \text{attn} = 2 H^2 + 2 H \, n_\text{kv} \, d, \qquad \text{mlp} = 3 H I \]

(full Q/O projections; K/V shrunk by GQA; gate/up/down for SwiGLU). Total parameters:

\[ N = V H + L (\text{attn} + \text{mlp}) + \begin{cases} 0 & \text{tied embeddings} \\ V H & \text{untied LM head} \end{cases} \]
References
  • Ainslie, J. et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023) -- reduced K/V projection shape. https://arxiv.org/abs/2305.13245
  • So, D. R. et al. "Primer: Searching for Efficient Transformers for Language Modeling" (2021) -- squared-ReLU MLP (Nemotron family), 2 projections; motivates the 50% over-count caveat above. https://arxiv.org/abs/2109.08668
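
As a worked instance of the formula, the sketch below plugs in shape values resembling a roughly 1.1B-parameter Llama-style model. The config values and the helper name shape_param_count are assumptions chosen for illustration; the arithmetic follows the attn/mlp/total expressions above.

```python
from types import SimpleNamespace


def shape_param_count(cfg):
    """Illustrative re-statement of the shape formula (hypothetical helper)."""
    H, L, I = cfg.hidden_size, cfg.num_hidden_layers, cfg.intermediate_size
    # Q and O projections are H x H; K and V shrink to H x (n_kv * d) under GQA.
    attn = 2 * H * H + 2 * H * cfg.num_key_value_heads * cfg.head_dim
    # Gate, up, and down projections of a gated MLP.
    mlp = 3 * H * I
    embed = cfg.vocab_size * H
    lm_head = 0 if cfg.tie_word_embeddings else embed
    return embed + L * (attn + mlp) + lm_head


# Assumed example shape; not read from any real checkpoint.
cfg = SimpleNamespace(
    hidden_size=2048,
    intermediate_size=5632,
    num_hidden_layers=22,
    vocab_size=32_000,
    num_key_value_heads=4,
    head_dim=64,
    tie_word_embeddings=False,
)

print(f"{shape_param_count(cfg):,}")  # 1,099,956,224

# A non-gated MLP has 2 projections (2*H*I), so the gated formula
# over-counts its MLP parameters by a factor of 3/2, i.e. 50%.
gated = 3 * cfg.hidden_size * cfg.intermediate_size
non_gated = 2 * cfg.hidden_size * cfg.intermediate_size
print(gated / non_gated)  # 1.5
```

The formula ignores normalization weights and biases, so small deviations from a real parameter count are expected even on a matching architecture.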
Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def estimate_params_from_shape(autoconfig: PretrainedConfig) -> int | None:
    r"""Shape-only fallback param count used when meta-tensor construction fails.

    Models a decoder-only transformer with grouped-query attention (which
    degrades to multi-head when ``num_key_value_heads == num_attention_heads``)
    and a gated SwiGLU/GeGLU MLP -- the shape NSS sees on its supported
    model families (Llama, Qwen, Mistral, SmolLM, Granite, TinyLlama). For
    non-gated variants (e.g. Nemotron's squared-ReLU MLP) this over-counts
    MLP params by 50%, which is why the meta-tensor path is preferred.

    With hidden size \(H\), intermediate size \(I\), \(L\) layers,
    vocabulary \(V\), \(n_\text{kv}\) KV heads and per-head dim \(d\), the
    per-layer cost is

    \[
        \text{attn} = 2 H^2 + 2 H \, n_\text{kv} \, d, \qquad
        \text{mlp}  = 3 H I
    \]

    (full Q/O projections; K/V shrunk by GQA; gate/up/down for SwiGLU).
    Total parameters:

    \[
        N = V H + L (\text{attn} + \text{mlp}) + \begin{cases}
            0      & \text{tied embeddings} \\
            V H    & \text{untied LM head}
        \end{cases}
    \]

    References:
        - Ainslie, J. et al. "GQA: Training Generalized Multi-Query
          Transformer Models from Multi-Head Checkpoints" (2023) --
          reduced K/V projection shape. <https://arxiv.org/abs/2305.13245>
        - So, D. R. et al. "Primer: Searching for Efficient Transformers
          for Language Modeling" (2021) -- squared-ReLU MLP (Nemotron
          family), 2 projections; motivates the 50% over-count caveat
          above. <https://arxiv.org/abs/2109.08668>
    """
    H = getattr(autoconfig, "hidden_size", None)
    L = getattr(autoconfig, "num_hidden_layers", None)
    if not (H and L):
        return None
    V = getattr(autoconfig, "vocab_size", 32_000) or 32_000
    inter = getattr(autoconfig, "intermediate_size", None) or 4 * H
    n_heads = getattr(autoconfig, "num_attention_heads", None) or max(1, H // 64)
    kv_heads = getattr(autoconfig, "num_key_value_heads", None) or n_heads
    head_dim = getattr(autoconfig, "head_dim", None) or max(1, H // max(n_heads, 1))
    tied = bool(getattr(autoconfig, "tie_word_embeddings", False))

    # Q proj (H×H) + O proj (H×H) = 2H²; K proj + V proj shrunk by GQA: 2·H·kv_heads·d
    attn = 2 * H * H + 2 * H * kv_heads * head_dim
    # gate proj (H×I) + up proj (H×I) + down proj (I×H) = 3HI  (SwiGLU / GeGLU gated MLP)
    mlp = 3 * H * inter
    per_layer = attn + mlp
    embed = V * H  # token embedding table
    lm_head = 0 if tied else V * H  # unembedding; zero when tied to embed
    return embed + L * per_layer + lm_head

estimate_base_model_params(autoconfig)

Return (n_params, method) for the base model, or None if unknown.

method == "exact" means the meta-tensor path succeeded and the count is architecture-accurate. method == "approximate" means the shape formula was used as a fallback (see estimate_params_from_shape for its known error modes) and the caller should flag the downstream VRAM estimate as heuristic. Benchmarked fallback error on supported architectures: \(-22\%\) to \(+33\%\); hybrid Mamba-Transformer models (e.g. Nemotron-H) can drift further.

Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def estimate_base_model_params(autoconfig: PretrainedConfig) -> tuple[int, Literal["exact", "approximate"]] | None:
    r"""Return ``(n_params, method)`` for the base model, or ``None`` if unknown.

    ``method == "exact"`` means the meta-tensor path succeeded and the count
    is architecture-accurate. ``method == "approximate"`` means the shape
    formula was used as a fallback (see
    [estimate_params_from_shape][nemo_safe_synthesizer.preflight.checks.environment.estimate_params_from_shape]
    for its known error modes) and the caller should flag the downstream VRAM
    estimate as heuristic. Benchmarked fallback error on supported
    architectures: \(-22\%\) to \(+33\%\); hybrid Mamba-Transformer models
    (e.g. Nemotron-H) can drift further.
    """
    exact = param_count_from_empty_model(autoconfig)
    if exact is not None:
        return exact, "exact"
    approx = estimate_params_from_shape(autoconfig)
    if approx is None:
        return None
    return approx, "approximate"

bytes_per_base_weight(training_cfg)

Return expected bytes/param for the base model load mode.

NSS always trains via LoRA-style adapters, so the base model's storage precision dominates VRAM (LoRA adapter params, gradients, and optimizer state are comparatively negligible).

Runtime quantization is controlled by training.quantize_model. The PEFT type string alone is not enough: with quantize_model=False the base weights are loaded as bf16 even when peft_implementation is configured as "QLORA".

  • Quantized load: \(\text{bits}/8 + 0.1\) to cover quant state (absmax / block scales) and dequant workspace. Yields \(\approx 0.6\) for 4-bit, \(\approx 1.1\) for 8-bit.
  • Unquantized load: \(2\) bytes (bf16/fp16 base weights).
References
  • Hu, E. J. et al. "LoRA: Low-Rank Adaptation of Large Language Models" (2021) -- base weights frozen; adapter + gradients + optimizer state are small relative to \(N b\). https://arxiv.org/abs/2106.09685
  • Dettmers, T. et al. "QLoRA: Efficient Finetuning of Quantized LLMs" (2023) -- 4-bit NF4 quantization with block-wise absmax scales; the \(+0.1\) term accounts for these scales and the dequant workspace. https://arxiv.org/abs/2305.14314
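
To give a feel for what these bytes/param values mean in practice, the sketch below converts them to GiB for a hypothetical 7B-parameter base model. The helper names and the 7B figure are assumptions; only the \(\text{bits}/8 + 0.1\) and \(2\)-byte rules come from the description above.

```python
GIB = 1024**3


def bytes_per_weight(quantized: bool, bits: int = 4) -> float:
    """Bytes per base-model parameter (hypothetical stand-in for the rule above)."""
    # Quantized: bits/8 plus 0.1 for quant state and dequant workspace.
    # Unquantized: 2 bytes for bf16/fp16 weights.
    return bits / 8 + 0.1 if quantized else 2.0


def base_weights_gib(n_params: int, quantized: bool, bits: int = 4) -> float:
    return n_params * bytes_per_weight(quantized, bits) / GIB


n_params = 7_000_000_000  # assumed model size, for illustration only

print(f"4-bit: {base_weights_gib(n_params, True, 4):.2f} GiB")   # 4-bit: 3.91 GiB
print(f"8-bit: {base_weights_gib(n_params, True, 8):.2f} GiB")   # 8-bit: 7.17 GiB
print(f"bf16:  {base_weights_gib(n_params, False):.2f} GiB")     # bf16:  13.04 GiB
```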
Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def bytes_per_base_weight(training_cfg: TrainingHyperparams) -> float:
    r"""Return expected bytes/param for the base model load mode.

    NSS always trains via LoRA-style adapters, so the base model's storage
    precision dominates VRAM (LoRA adapter params, gradients, and
    optimizer state are comparatively negligible).

    Runtime quantization is controlled by ``training.quantize_model``. The
    PEFT type string alone is not enough: with ``quantize_model=False`` the
    base weights are loaded as bf16 even when ``peft_implementation`` is
    configured as ``"QLORA"``.

    - Quantized load: \(\text{bits}/8 + 0.1\) to cover quant state (absmax /
      block scales) and dequant workspace. Yields \(\approx 0.6\) for 4-bit,
      \(\approx 1.1\) for 8-bit.
    - Unquantized load: \(2\) bytes (bf16/fp16 base weights).

    References:
        - Hu, E. J. et al. "LoRA: Low-Rank Adaptation of Large Language
          Models" (2021) -- base weights frozen; adapter + gradients +
          optimizer state are small relative to \(N b\).
          <https://arxiv.org/abs/2106.09685>
        - Dettmers, T. et al. "QLoRA: Efficient Finetuning of Quantized
          LLMs" (2023) -- 4-bit NF4 quantization with block-wise absmax
          scales; the \(+0.1\) term accounts for these scales and the
          dequant workspace. <https://arxiv.org/abs/2305.14314>
    """
    if training_cfg.quantize_model:
        # Prefer the explicit scheme if set; otherwise fall back to the legacy
        # bits-based field. Both routes yield bits/param for memory estimation.
        if training_cfg.quantization_scheme is not None:
            bits = training_cfg.quantization_scheme.effective_bits
        else:
            bits = training_cfg.quantization_bits
        return bits / 8 + 0.1
    return 2.0

activation_memory_gib(*, batch_size, seq_len, hidden_size, num_hidden_layers, bytes_per_activation_element=2.0)

Rough activation VRAM on one device given micro-batch geometry.

Uses training.batch_size (HF per_device_train_batch_size), not gradient_accumulation_steps. Matches bf16-ish training tensors at 2 bytes/element:

\[ M_\text{act} \approx B \cdot S \cdot H \cdot L \cdot 2\text{ bytes} \]

Omits attention \(O(B S^2)\) blocks and recomputation specifics; the goal is an order-of-magnitude headroom check against absurd batch_size values.
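
A quick numeric check of the \(B \cdot S \cdot H \cdot L \cdot 2\) estimate, using an assumed micro-batch geometry (the values and the helper name act_gib are illustrative, not taken from any NSS config):

```python
GIB = 1024**3


def act_gib(batch_size, seq_len, hidden_size, num_layers, bytes_per_element=2.0):
    """B * S * H * L * bytes, converted to GiB (illustrative helper)."""
    return batch_size * seq_len * hidden_size * num_layers * bytes_per_element / GIB


# 4 * 2048 * 4096 * 32 * 2 bytes = 2**31 bytes
print(act_gib(4, 2048, 4096, 32))  # 2.0

# The estimate is linear in batch size, so doubling B doubles the result.
print(act_gib(8, 2048, 4096, 32))  # 4.0
```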

Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def activation_memory_gib(
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    num_hidden_layers: int,
    bytes_per_activation_element: float = 2.0,
) -> float:
    r"""Rough activation VRAM on one device given micro-batch geometry.

    Uses ``training.batch_size`` (HF ``per_device_train_batch_size``), not
    ``gradient_accumulation_steps``. Matches bf16-ish training tensors at
    2 bytes/element:

    \[
        M_\text{act} \approx B \cdot S \cdot H \cdot L \cdot 2\text{ bytes}
    \]

    Omits attention \(O(B S^2)\) blocks and recomputation specifics; the goal
    is an order-of-magnitude headroom check against absurd ``batch_size`` values.

    References:
        - Korthikanti, V. et al. (2022) -- recomputation vs stored activations.
          <https://arxiv.org/abs/2205.05198>
    """
    nbytes = batch_size * seq_len * hidden_size * num_hidden_layers * bytes_per_activation_element
    return nbytes / (1024**3)

estimate_training_vram_components(*, n_params, training_cfg, batch_size, seq_len, hidden_size, num_hidden_layers, bytes_per_activation_element=2.0)

Compose base weights, overhead, and optional activation estimate (GiB).
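
The composition can be sketched as follows. The two overhead constants are placeholders, since the real values of _VRAM_KERNEL_RESERVED_GIB and _VRAM_LEGACY_OVERHEAD_GIB are not shown on this page, and the positive-integer test only approximates what _positive_int_scalar might do.

```python
GIB = 1024**3

# Placeholder values, NOT the real module constants.
KERNEL_RESERVED_GIB = 1.0
LEGACY_OVERHEAD_GIB = 4.0


def total_vram_gib(n_params, bytes_per_weight, batch_size,
                   seq_len=None, hidden_size=None, num_layers=None):
    """Return (base, overhead, activation, total) in GiB (illustrative sketch)."""
    base = n_params * bytes_per_weight / GIB
    dims = (batch_size, seq_len, hidden_size, num_layers)
    if all(isinstance(d, int) and d > 0 for d in dims):
        # Full geometry known: explicit activation term, small reserved overhead.
        activation = batch_size * seq_len * hidden_size * num_layers * 2.0 / GIB
        overhead = KERNEL_RESERVED_GIB
    else:
        # Geometry incomplete: no activation term, larger lumped overhead.
        activation = None
        overhead = LEGACY_OVERHEAD_GIB
    total = base + overhead + (activation or 0.0)
    return base, overhead, activation, total


# 2**30 params at 2 bytes each is exactly 2 GiB of base weights.
print(total_vram_gib(2**30, 2.0, 4, 2048, 4096, 32))  # (2.0, 1.0, 2.0, 5.0)
print(total_vram_gib(2**30, 2.0, 4))                  # (2.0, 4.0, None, 6.0)
```

Keeping the activation term as None rather than 0.0 when inputs are missing lets a caller tell "not estimated" apart from "estimated as zero".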

Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def estimate_training_vram_components(
    *,
    n_params: int,
    training_cfg: TrainingHyperparams,
    batch_size: int,
    seq_len: int | None,
    hidden_size: int | None,
    num_hidden_layers: int | None,
    bytes_per_activation_element: float = 2.0,
) -> VRAMComponentEstimate:
    """Compose base weights, overhead, and optional activation estimate (GiB)."""
    bpw = bytes_per_base_weight(training_cfg)
    base_weights_gib = (n_params * bpw) / (1024**3)

    b_sz = _positive_int_scalar(batch_size)
    seq = _positive_int_scalar(seq_len) if seq_len is not None else None
    h_sz = _positive_int_scalar(hidden_size) if hidden_size is not None else None
    n_layers = _positive_int_scalar(num_hidden_layers) if num_hidden_layers is not None else None

    activation_gib: float | None
    overhead_gib: float
    if b_sz is not None and seq is not None and h_sz is not None and n_layers is not None:
        activation_gib = activation_memory_gib(
            batch_size=b_sz,
            seq_len=seq,
            hidden_size=h_sz,
            num_hidden_layers=n_layers,
            bytes_per_activation_element=bytes_per_activation_element,
        )
        overhead_gib = _VRAM_KERNEL_RESERVED_GIB
    else:
        activation_gib = None
        overhead_gib = _VRAM_LEGACY_OVERHEAD_GIB

    total_gib = base_weights_gib + overhead_gib + (activation_gib if activation_gib is not None else 0.0)

    return VRAMComponentEstimate(
        base_weights_gib=base_weights_gib,
        overhead_gib=overhead_gib,
        activation_gib=activation_gib,
        total_gib=total_gib,
    )