linear

Per-sample gradient hook for nn.Linear (Opacus).

Registering this module with Opacus ensures per-sample gradients are computed correctly for Linear layers. Import this module for its side effect; do not call compute_linear_grad_sample directly.

Functions:

| Name | Description |
| --- | --- |
| `compute_linear_grad_sample` | Compute per-sample gradients for an `nn.Linear` layer. |

compute_linear_grad_sample(layer, activations, backprops)

Compute per-sample gradients for an nn.Linear layer.

Used by Opacus for correct per-sample gradient accumulation. Converts activations and backprops to float for mixed-precision compatibility.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `layer` | `Linear` | The Linear layer being sampled. | required |
| `activations` | `list[Tensor]` | List of activation tensors from the forward pass. | required |
| `backprops` | `Tensor` | Backpropagated gradient tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[Parameter, Tensor]` | Dictionary mapping each trainable parameter (weight, bias) to its per-sample gradient tensor of shape `(batch, ...)`. |

Source code in src/nemo_safe_synthesizer/privacy/dp_transformers/linear.py
import torch
from opacus.grad_sample import register_grad_sampler
from opt_einsum import contract
from torch import nn

@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample(
    layer: nn.Linear, activations: list[torch.Tensor], backprops: torch.Tensor
) -> dict[nn.Parameter, torch.Tensor]:
    """Compute per-sample gradients for an ``nn.Linear`` layer.

    Used by Opacus for correct per-sample gradient accumulation. Converts
    activations and backprops to float for mixed-precision compatibility.

    Args:
        layer: The Linear layer being sampled.
        activations: List of activation tensors from the forward pass.
        backprops: Backpropagated gradient tensor.

    Returns:
        Dictionary mapping each trainable parameter (weight, bias) to its
        per-sample gradient tensor of shape ``(batch, ...)``.
    """
    activation = activations[0]
    ret = {}
    if layer.weight.requires_grad:
        gs = contract("n...i,n...j->nij", backprops.float(), activation.float())
        ret[layer.weight] = gs
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = contract("n...k->nk", backprops.float())
    return ret