environment

Environment-stage checks: GPU, VRAM, tokens, log settings.

Classes:

| Name | Description |
| --- | --- |
| CUDAAvailabilityCheck | Validate CUDA GPU availability. |
| VRAMHeadroomCheck | Estimate whether GPU VRAM is sufficient for training. |
| InferenceKeyCheck | Check NSS_INFERENCE_KEY environment variable. |
| HFModelAvailabilityCheck | Validate local model, HF cache, and online HF access readiness. |

Functions:

| Name | Description |
| --- | --- |
| param_count_from_empty_model | Count parameters by instantiating the model on the meta device. |
| estimate_params_from_shape | Shape-only fallback param count used when meta-tensor construction fails. |
| estimate_base_model_params | Return (n_params, method) for the base model, or None if unknown. |
| bytes_per_base_weight | Return expected bytes/param for the base model given PEFT mode. |

CUDAAvailabilityCheck

Bases: ConfigCheck

Validate CUDA GPU availability.

VRAMHeadroomCheck

Bases: MetadataCheck

Estimate whether GPU VRAM is sufficient for training.

The estimate is intentionally a lower bound:

\[ \text{VRAM}_\text{est} = N \cdot b + C \]

where \(N\) is the base-model parameter count (see estimate_base_model_params; exact via the meta-tensor path, or the shape-heuristic fallback), \(b\) is the bytes-per-param for the selected PEFT mode (see bytes_per_base_weight), and \(C\) is a fixed overhead for CUDA kernels and checkpointed activations.

The expression excludes the fine-grained activation term \(\mathcal{O}(B \cdot S \cdot H \cdot L)\), LoRA adapter parameters, gradients, and optimizer state. For parameter-efficient fine-tuning these are typically small compared to the base weights, but not zero. Passing this check therefore does not guarantee that training will fit in VRAM; failing it is a strong signal that training will run out of memory (OOM).
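To make the arithmetic concrete, here is a minimal sketch of the lower-bound formula. The helper `vram_lower_bound_bytes` is hypothetical (it is not part of this module), and the 7-billion-parameter count and the 2 GiB value for \(C\) are illustrative assumptions; this page does not state the actual overhead constant. The bytes-per-param values are the ones documented for bytes_per_base_weight.

```python
GIB = 1024**3


def vram_lower_bound_bytes(n_params: int, bytes_per_param: float, overhead_bytes: int) -> float:
    """VRAM_est = N * b + C.

    A lower bound only: activations, LoRA adapter parameters, gradients,
    and optimizer state are deliberately left out.
    """
    return n_params * bytes_per_param + overhead_bytes


n_params = 7_000_000_000      # assumed base-model size, for illustration
assumed_overhead = 2 * GIB    # assumed value of C; not taken from the source

for mode, b in (("QLoRA 4-bit", 0.6), ("QLoRA 8-bit", 1.1), ("LoRA bf16/fp16", 2.0)):
    est = vram_lower_bound_bytes(n_params, b, assumed_overhead)
    print(f"{mode}: at least {est / GIB:.1f} GiB")
```

Because the result is a lower bound, a device whose free VRAM is below it is very unlikely to fit the run, while a device above it may still run out of memory once activations and optimizer state are added.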

InferenceKeyCheck

Bases: ConfigCheck

Check NSS_INFERENCE_KEY environment variable.

HFModelAvailabilityCheck

Bases: ConfigCheck

Validate local model, HF cache, and online HF access readiness.

param_count_from_empty_model(autoconfig)

Count parameters by instantiating the model on the meta device.

accelerate.init_empty_weights constructs the full nn.Module graph with every parameter on torch.device("meta") -- no storage is allocated and no weights are downloaded. AutoModelForCausalLM.from_config consults the transformers model-class registry to pick the right architecture (handling Nemotron's non-gated MLP, MoE experts, biases, tied embeddings, and any future variant automatically).

Returns None if accelerate/transformers are missing, the config doesn't map to a registered architecture (e.g. trust_remote_code custom archs), or instantiation fails for any other reason. The caller should fall back to estimate_params_from_shape.

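References
  • HuggingFace accelerate, "Big Model Inference". https://huggingface.co/docs/accelerate/concept_guides/big_model_inference
  • HuggingFace accelerate, "Model memory estimator" -- same meta-device technique exposed as accelerate estimate-memory; reported accurate to within a few percent of real CUDA load. https://huggingface.co/docs/accelerate/usage_guides/model_size_estimator
  • PyTorch meta device. https://docs.pytorch.org/docs/stable/meta.html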
Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def param_count_from_empty_model(autoconfig: PretrainedConfig) -> int | None:
    """Count parameters by instantiating the model on the ``meta`` device.

    ``accelerate.init_empty_weights`` constructs the full ``nn.Module`` graph
    with every parameter on ``torch.device("meta")`` -- no storage is
    allocated and no weights are downloaded. ``AutoModelForCausalLM.from_config``
    consults the transformers model-class registry to pick the right
    architecture (handling Nemotron's non-gated MLP, MoE experts, biases,
    tied embeddings, and any future variant automatically).

    Returns ``None`` if accelerate/transformers are missing, the config
    doesn't map to a registered architecture (e.g. ``trust_remote_code``
    custom archs), or instantiation fails for any other reason. The caller
    should fall back to
    [estimate_params_from_shape][nemo_safe_synthesizer.preflight.checks.environment.estimate_params_from_shape].

    References:
        - HuggingFace accelerate, "Big Model Inference" --
          <https://huggingface.co/docs/accelerate/concept_guides/big_model_inference>
        - HuggingFace accelerate, "Model memory estimator" -- same
          meta-device technique exposed as ``accelerate estimate-memory``;
          reported accurate to within a few percent of real CUDA load.
          <https://huggingface.co/docs/accelerate/usage_guides/model_size_estimator>
        - PyTorch meta device --
          <https://docs.pytorch.org/docs/stable/meta.html>
    """
    try:
        from accelerate import init_empty_weights
        from transformers import AutoModelForCausalLM
    except ImportError:
        return None
    try:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(autoconfig)
        return sum(p.numel() for p in model.parameters())
    except Exception as exc:
        logger.runtime.debug(
            "param_count_from_empty_model failed for %r: %s: %s",
            getattr(autoconfig, "_name_or_path", None) or getattr(autoconfig, "model_type", "unknown"),
            type(exc).__name__,
            exc,
        )
        return None

estimate_params_from_shape(autoconfig)

Shape-only fallback param count used when meta-tensor construction fails.

Models a decoder-only transformer with grouped-query attention (which degrades to multi-head when num_key_value_heads == num_attention_heads) and a gated SwiGLU/GeGLU MLP -- the shape NSS sees on its supported model families (Llama, Qwen, Mistral, SmolLM, Granite, TinyLlama). For non-gated variants (e.g. Nemotron's squared-ReLU MLP) this over-counts MLP params by 50%, which is why the meta-tensor path is preferred.

With hidden size \(H\), intermediate size \(I\), \(L\) layers, vocabulary \(V\), \(n_\text{kv}\) KV heads and per-head dim \(d\), the per-layer cost is

\[ \text{attn} = 2 H^2 + 2 H \, n_\text{kv} \, d, \qquad \text{mlp} = 3 H I \]

(full Q/O projections; K/V shrunk by GQA; gate/up/down for SwiGLU). Total parameters:

\[ N = V H + L (\text{attn} + \text{mlp}) + \begin{cases} 0 & \text{tied embeddings} \\ V H & \text{untied LM head} \end{cases} \]
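As a worked instance of the formula above, the following sketch evaluates it for one set of dimensions. The helper name `shape_param_count` and the chosen dimensions (a 7B-class, Llama-style shape with full multi-head attention and an untied LM head) are assumptions for illustration. Note that the formula omits normalization weights and biases, so real checkpoints differ slightly.

```python
def shape_param_count(H: int, I: int, L: int, V: int, n_kv: int, d: int, tied: bool) -> int:
    """Evaluate N = V*H + L*(attn + mlp) + (0 if tied else V*H)."""
    attn = 2 * H * H + 2 * H * n_kv * d   # Q/O full size; K/V shrunk by GQA
    mlp = 3 * H * I                       # gate + up + down projections
    lm_head = 0 if tied else V * H
    return V * H + L * (attn + mlp) + lm_head


# Assumed dimensions, for illustration only.
n = shape_param_count(H=4096, I=11008, L=32, V=32_000, n_kv=32, d=128, tied=False)
print(f"{n:,}")  # 6,738,149,376

# The 50% MLP over-count for non-gated variants: 3 projections counted, 2 present.
gated_mlp = 3 * 4096 * 11008
non_gated_mlp = 2 * 4096 * 11008
print(gated_mlp / non_gated_mlp)  # 1.5
```

Setting `n_kv` below the number of attention heads (for example 8 instead of 32) reduces only the K/V term, which is the GQA saving the formula captures.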
References
  • Ainslie, J. et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023) -- reduced K/V projection shape. https://arxiv.org/abs/2305.13245
  • So, D. R. et al. "Primer: Searching for Efficient Transformers for Language Modeling" (2021) -- squared-ReLU MLP (Nemotron family), 2 projections; motivates the 50% over-count caveat above. https://arxiv.org/abs/2109.08668
Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def estimate_params_from_shape(autoconfig: PretrainedConfig) -> int | None:
    r"""Shape-only fallback param count used when meta-tensor construction fails.

    Models a decoder-only transformer with grouped-query attention (which
    degrades to multi-head when ``num_key_value_heads == num_attention_heads``)
    and a gated SwiGLU/GeGLU MLP -- the shape NSS sees on its supported
    model families (Llama, Qwen, Mistral, SmolLM, Granite, TinyLlama). For
    non-gated variants (e.g. Nemotron's squared-ReLU MLP) this over-counts
    MLP params by 50%, which is why the meta-tensor path is preferred.

    With hidden size \(H\), intermediate size \(I\), \(L\) layers,
    vocabulary \(V\), \(n_\text{kv}\) KV heads and per-head dim \(d\), the
    per-layer cost is

    \[
        \text{attn} = 2 H^2 + 2 H \, n_\text{kv} \, d, \qquad
        \text{mlp}  = 3 H I
    \]

    (full Q/O projections; K/V shrunk by GQA; gate/up/down for SwiGLU).
    Total parameters:

    \[
        N = V H + L (\text{attn} + \text{mlp}) + \begin{cases}
            0      & \text{tied embeddings} \\
            V H    & \text{untied LM head}
        \end{cases}
    \]

    References:
        - Ainslie, J. et al. "GQA: Training Generalized Multi-Query
          Transformer Models from Multi-Head Checkpoints" (2023) --
          reduced K/V projection shape. <https://arxiv.org/abs/2305.13245>
        - So, D. R. et al. "Primer: Searching for Efficient Transformers
          for Language Modeling" (2021) -- squared-ReLU MLP (Nemotron
          family), 2 projections; motivates the 50% over-count caveat
          above. <https://arxiv.org/abs/2109.08668>
    """
    H = getattr(autoconfig, "hidden_size", None)
    L = getattr(autoconfig, "num_hidden_layers", None)
    if not (H and L):
        return None
    V = getattr(autoconfig, "vocab_size", 32_000) or 32_000
    inter = getattr(autoconfig, "intermediate_size", None) or 4 * H
    n_heads = getattr(autoconfig, "num_attention_heads", None) or max(1, H // 64)
    kv_heads = getattr(autoconfig, "num_key_value_heads", None) or n_heads
    head_dim = getattr(autoconfig, "head_dim", None) or max(1, H // max(n_heads, 1))
    tied = bool(getattr(autoconfig, "tie_word_embeddings", False))

    # Q proj (H×H) + O proj (H×H) = 2H²; K proj + V proj shrunk by GQA: 2·H·kv_heads·d
    attn = 2 * H * H + 2 * H * kv_heads * head_dim
    # gate proj (H×I) + up proj (H×I) + down proj (I×H) = 3HI  (SwiGLU / GeGLU gated MLP)
    mlp = 3 * H * inter
    per_layer = attn + mlp
    embed = V * H  # token embedding table
    lm_head = 0 if tied else V * H  # unembedding; zero when tied to embed
    return embed + L * per_layer + lm_head

estimate_base_model_params(autoconfig)

Return (n_params, method) for the base model, or None if unknown.

method == "exact" means the meta-tensor path succeeded and the count is architecture-accurate. method == "approximate" means the shape formula was used as a fallback (see estimate_params_from_shape for its known error modes) and the caller should flag the downstream VRAM estimate as heuristic. Benchmarked fallback error on supported architectures: \(-22\%\) to \(+33\%\); hybrid Mamba-Transformer models (e.g. Nemotron-H) can drift further.
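One way a caller might act on the returned pair is sketched below. `describe_estimate` is a hypothetical helper, not part of this module. It also assumes the benchmarked error is defined as (estimate - true) / true, so the true count lies between estimate / 1.33 and estimate / 0.78; the source does not spell out that definition.

```python
from typing import Literal, Optional, Tuple

Estimate = Optional[Tuple[int, Literal["exact", "approximate"]]]


def describe_estimate(result: Estimate) -> str:
    """Turn an (n_params, method) result into a human-readable label."""
    if result is None:
        # Neither the meta-tensor path nor the shape formula produced a count.
        return "parameter count unknown"
    n_params, method = result
    if method == "exact":
        return f"{n_params:,} params (exact)"
    # Assumption: error = (estimate - true) / true, within [-22%, +33%].
    low = n_params / 1.33    # estimate was 33% too high
    high = n_params / 0.78   # estimate was 22% too low
    return f"~{n_params:,} params (heuristic; true count plausibly {low:,.0f} to {high:,.0f})"


print(describe_estimate((7_000_000_000, "approximate")))
```

Surfacing the band alongside the point estimate keeps a heuristic VRAM figure from being read as exact.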

Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def estimate_base_model_params(autoconfig: PretrainedConfig) -> tuple[int, Literal["exact", "approximate"]] | None:
    r"""Return ``(n_params, method)`` for the base model, or ``None`` if unknown.

    ``method == "exact"`` means the meta-tensor path succeeded and the count
    is architecture-accurate. ``method == "approximate"`` means the shape
    formula was used as a fallback (see
    [estimate_params_from_shape][nemo_safe_synthesizer.preflight.checks.environment.estimate_params_from_shape]
    for its known error modes) and the caller should flag the downstream VRAM
    estimate as heuristic. Benchmarked fallback error on supported
    architectures: \(-22\%\) to \(+33\%\); hybrid Mamba-Transformer models
    (e.g. Nemotron-H) can drift further.
    """
    exact = param_count_from_empty_model(autoconfig)
    if exact is not None:
        return exact, "exact"
    approx = estimate_params_from_shape(autoconfig)
    if approx is None:
        return None
    return approx, "approximate"

bytes_per_base_weight(training_cfg)

Return expected bytes/param for the base model given PEFT mode.

NSS always trains via LoRA or QLoRA, so the base model's storage precision dominates VRAM (LoRA adapter params, gradients, and optimizer state are comparatively negligible).

  • QLoRA: \(\text{bits}/8 + 0.1\) to cover quant state (absmax / block scales) and dequant workspace. Yields \(\approx 0.6\) for 4-bit, \(\approx 1.1\) for 8-bit.
  • LoRA (unquantized): \(2\) bytes (bf16/fp16 base weights).
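The two rules above translate into a convenient rule of thumb per billion parameters. The sketch below mirrors the documented rule with plain arguments instead of a `TrainingHyperparams` object; the function name `bytes_per_param` and the 1-billion-parameter example are assumptions for illustration.

```python
def bytes_per_param(peft: str, quantization_bits: int = 4) -> float:
    """Bytes per base-model parameter: bits/8 + 0.1 for QLoRA, else 2 (bf16/fp16)."""
    if peft.upper() == "QLORA":
        return quantization_bits / 8 + 0.1
    return 2.0


n_params = 1_000_000_000  # one billion parameters, for illustration
for label, b in (
    ("QLoRA 4-bit", bytes_per_param("qlora", 4)),
    ("QLoRA 8-bit", bytes_per_param("qlora", 8)),
    ("LoRA", bytes_per_param("lora")),
):
    print(f"{label}: {b:.1f} bytes/param, {n_params * b / 1e9:.1f} GB of base weights per billion params")
```

Read this way, 4-bit QLoRA needs roughly 0.6 GB of base weights per billion parameters, against 2 GB for unquantized LoRA.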
References
  • Hu, E. J. et al. "LoRA: Low-Rank Adaptation of Large Language Models" (2021) -- base weights frozen; adapter + gradients + optimizer state are small relative to \(N b\). https://arxiv.org/abs/2106.09685
  • Dettmers, T. et al. "QLoRA: Efficient Finetuning of Quantized LLMs" (2023) -- 4-bit NF4 quantization with block-wise absmax scales; the \(+0.1\) term accounts for these scales and the dequant workspace. https://arxiv.org/abs/2305.14314
Source code in src/nemo_safe_synthesizer/preflight/checks/environment.py
def bytes_per_base_weight(training_cfg: TrainingHyperparams) -> float:
    r"""Return expected bytes/param for the base model given PEFT mode.

    NSS always trains via LoRA or QLoRA, so the base model's storage
    precision dominates VRAM (LoRA adapter params, gradients, and
    optimizer state are comparatively negligible).

    - QLoRA: \(\text{bits}/8 + 0.1\) to cover quant state (absmax / block
      scales) and dequant workspace. Yields \(\approx 0.6\) for 4-bit,
      \(\approx 1.1\) for 8-bit.
    - LoRA (unquantized): \(2\) bytes (bf16/fp16 base weights).

    References:
        - Hu, E. J. et al. "LoRA: Low-Rank Adaptation of Large Language
          Models" (2021) -- base weights frozen; adapter + gradients +
          optimizer state are small relative to \(N b\).
          <https://arxiv.org/abs/2106.09685>
        - Dettmers, T. et al. "QLoRA: Efficient Finetuning of Quantized
          LLMs" (2023) -- 4-bit NF4 quantization with block-wise absmax
          scales; the \(+0.1\) term accounts for these scales and the
          dequant workspace. <https://arxiv.org/abs/2305.14314>
    """
    if training_cfg.peft_implementation.upper() == "QLORA":
        return training_cfg.quantization_bits / 8 + 0.1
    return 2.0