Skip to content

vllm_backend

vllm_backend

vLLM-based generation backend for tabular data synthesis.

Classes:

Name Description
VllmBackend

Generation backend using vLLM for high-throughput inference.

VllmBackend(config, model_metadata, workdir, **kwargs)

Bases: GeneratorBackend

Generation backend using vLLM for high-throughput inference.

Loads the base model with a LoRA adapter via vLLM and generates synthetic records in batches. Supports optional structured generation (regex or JSON schema) to constrain outputs.

LoRARequest("lora", 1, str(adapter_path)) is passed to llm.generate when an adapter is available. The vLLM engine uses config.training.lora_r as max_lora_rank.

Parameters:

Name Type Description Default
config SafeSynthesizerParameters

Pipeline configuration.

required
model_metadata ModelMetadata

Model metadata (prompt template, adapter path, sequence length, etc.).

required
workdir Workdir

Working directory containing the adapter and schema.

required
**kwargs

Additional options. use_detailed_logs (bool) enables verbose error messages (disabled by default to avoid leaking sensitive data).

{}

Methods:

Name Description
teardown

Release GPU memory and distributed resources. Idempotent -- safe to call multiple times.

initialize

Initialize and load the model into memory.

prepare_params

Parse parameters and configure the generation method.

generate

Generate synthetic tabular data in batches until the target count is reached.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def __init__(
    self,
    config: SafeSynthesizerParameters,
    model_metadata: ModelMetadata,
    workdir: Workdir,
    **kwargs,
):
    """Set up backend state; the vLLM engine itself is created by ``initialize()``.

    Loads the JSON schema from the working directory, derives the column
    list and the schema prompt from it, builds a tokenizer-less record
    processor, and prepares the LoRA request when an adapter path is known.

    Args:
        config: Pipeline configuration.
        model_metadata: Model metadata (instruction, prompt template,
            adapter path, ...).
        workdir: Working directory providing the schema file and,
            optionally, the adapter path.
        **kwargs: Additional options. ``use_detailed_logs`` (bool,
            default ``False``) enables verbose error messages.
    """
    self.model_metadata = model_metadata
    self.config = config
    self.remote = False
    self.workdir = workdir

    # Schema-derived state: column names come from the schema's "properties".
    schema = load_json(workdir.schema_file)
    self.schema = schema
    self.columns = list(schema["properties"])
    self.prompt = utils.create_schema_prompt(
        self.columns,
        instruction=model_metadata.instruction,
        prompt_template=model_metadata.prompt_config.template,
    )

    # Engine-dependent state stays empty until ``initialize()`` has run.
    self.llm: vLLM | None = None
    self._prompt_token_count: int | None = None

    # Detailed error messages are opt-in so that production logs do not
    # leak sensitive data.
    self.use_detailed_logs = kwargs.pop("use_detailed_logs", False)
    self.gen_method: partial | None = None
    self._gen_method: partial | None = None

    # Provisional processor built without a tokenizer. ``initialize()`` swaps
    # it for a tokenizer-aware one once the vLLM engine exists. Callers can
    # therefore inspect the processor type up front, but token counts are
    # zero until the tokenizer is attached.
    self.processor: Processor = create_processor(schema, model_metadata, config)

    # The workdir's adapter path wins; model metadata is the fallback.
    adapter_location = workdir.adapter_path or model_metadata.adapter_path
    if adapter_location:
        self.lora_req = LoRARequest("lora", 1, str(adapter_location))
    else:
        self.lora_req = None
    self._torn_down = False

teardown()

Release GPU memory and distributed resources. Idempotent -- safe to call multiple times.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def teardown(self) -> None:
    """Release GPU memory and distributed resources. Idempotent -- safe to call multiple times."""
    if self._torn_down:
        return
    self._torn_down = True

    try:
        cleanup_dist_env_and_memory()
    except Exception:
        logger.debug("cleanup_dist_env_and_memory failed during teardown", exc_info=True)

    self.llm = None
    self._gen_method = None
    self.gen_method = None

    try:
        cleanup_memory()
    except Exception:
        logger.debug("cleanup_memory failed during teardown", exc_info=True)

initialize(**kwargs)

Initialize and load the model into memory.

Creates the vLLM engine and then builds the record processor with the engine's tokenizer so that exact token counts are available during generation.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def initialize(self, **kwargs) -> None:
    """Create the vLLM engine and rebuild the processor with its tokenizer.

    The engine is constructed with LoRA enabled, using
    ``config.training.lora_r`` as the maximum LoRA rank. Afterwards the
    record processor is recreated with the engine's tokenizer so that
    exact token counts are available during generation.

    Args:
        **kwargs: Accepted for interface compatibility; not used here.
    """
    # A new engine is about to be created, so teardown() must run again later.
    self._torn_down = False

    # vLLM 0.12+ takes attention_config as a constructor argument (it replaces
    # the VLLM_ATTENTION_BACKEND env var used in 0.11.x). No explicit config is
    # passed when the backend is unset or "auto".
    backend_name = self.config.generation.attention_backend
    if backend_name in (None, "auto"):
        attention_config = None
    else:
        attention_config = {"backend": backend_name}

    # Only device 0's entry is consulted (0.8 if absent), so this only works
    # for single GPU setups.
    gpu_memory_fraction = get_max_vram().get(0, 0.8)

    # vLLM needs the structured-output backend chosen ahead of time via this config.
    structured_cfg = StructuredOutputsConfig(
        backend=self.config.generation.structured_generation_backend,
    )
    model_ref = ModelRef.parse(self.config.training.pretrained_model)

    with heartbeat("Model loading", logger_name=__name__, model=self.config.training.pretrained_model):
        self.llm = vLLM(
            model=model_ref.target(),
            gpu_memory_utilization=gpu_memory_fraction,
            enable_lora=True,
            max_lora_rank=self.config.training.lora_r,
            structured_outputs_config=structured_cfg,
            attention_config=attention_config,
            trust_remote_code=model_ref.trust_remote_code,
        )

    # get_tokenizer() is typed as a wider union than HF's PreTrainedTokenizerBase;
    # in practice it is always a HF tokenizer subclass, hence the cast.
    engine_tokenizer = cast(PreTrainedTokenizerBase, self.llm.get_tokenizer())
    self.processor = create_processor(
        self.schema,
        self.model_metadata,
        self.config,
        tokenizer=engine_tokenizer,
    )

prepare_params(**kwargs)

Parse parameters and configure the generation method.

Parses a dictionary of parameters into SamplingParams, applying necessary transformations from the Safe Synthesizer API to vLLM's API. num_beams is mapped to beam_width only when greater than 1; otherwise it is omitted.

Parameters:

Name Type Description Default
**kwargs

Sampling parameters to configure.

{}
Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def prepare_params(self, **kwargs) -> None:
    """Parse parameters and configure the generation method.

    Parses a dictionary of parameters into ``SamplingParams``, applying
    necessary transformations from the Safe Synthesizer API to vLLM's API.
    ``num_beams`` is mapped to ``beam_width`` only when greater than 1;
    otherwise it is omitted. The result is stored on ``self._gen_method``
    as a partial of ``self.llm.generate`` with the sampling parameters and
    the LoRA request already bound.

    Args:
        **kwargs: Sampling parameters to configure.

    Raises:
        InternalError: If called before ``initialize()`` (``self.llm`` is
            ``None``).
    """
    # Fail fast: without an engine there is nothing to bind the parameters
    # to, so do not build them first.
    if self.llm is None:
        raise InternalError("VllmBackend.prepare_params() called before initialize() -- self.llm is None.")

    kwargs["structured_outputs"] = self._build_structured_output_params()

    resolved_temperature = self._resolve_temperature(kwargs)
    api_mapping = self._get_api_param_mapping(resolved_temperature)
    sampling_params = self._transform_kwargs_to_sampling_params(kwargs, api_mapping)

    real_params = SamplingParams(**sampling_params)
    # Lazy %-formatting: the repr is only computed when debug logging is on.
    logger.debug("SamplingParams: %r", real_params)

    # Create a partially parametrized version of the underlying vllm.LLM.generate
    # method that is immediately callable downstream.
    self._gen_method = partial(
        self.llm.generate,
        sampling_params=real_params,
        lora_request=self.lora_req,
        # Show vLLM's tqdm progress bar only when debug logging is enabled.
        use_tqdm=logger.isEnabledFor(logging.DEBUG),
    )

generate(data_actions_fn=None)

Generate synthetic tabular data in batches until the target count is reached.

Iterates over generation batches, applying the processor to each LLM output, until the configured num_records target is met or a stopping condition fires.

Non-tabular processors need BOS/EOS delimiters in the raw text, so generation keeps special tokens for those processors and strips them only for TabularDataProcessor. Native EOS stopping remains enabled through ignore_eos=False.

Parameters:

Name Type Description Default
data_actions_fn DataActionsFn | None

Optional post-processing / validation function applied to each batch of generated records.

None

Returns:

Type Description
GenerateJobResults

Results containing the generated DataFrame and statistics.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def generate(
    self,
    data_actions_fn: utils.DataActionsFn | None = None,
) -> GenerateJobResults:
    """Produce synthetic records batch by batch until the target is reached.

    Each iteration asks the batch tracker how many prompts to issue,
    generates one batch, records it, and logs a summary. The loop ends
    when the number of valid records reaches the configured
    ``num_records`` or the tracker reports a stopping status.

    Special tokens are stripped from the output only when the processor
    is a ``TabularDataProcessor``; every other processor needs the BOS/EOS
    delimiters in the raw text, so they are kept (and the stop string is
    included in the output). Native EOS stopping stays enabled through
    ``ignore_eos=False``.

    Args:
        data_actions_fn: Optional post-processing / validation function
            applied to each batch of generated records.

    Returns:
        Results containing the generated DataFrame and statistics.
    """
    started_at = time.monotonic()
    generation_cfg = self.config.generation

    keep_special_tokens = not isinstance(self.processor, TabularDataProcessor)
    sampling_kwargs = {
        "temperature": generation_cfg.temperature,
        "repetition_penalty": generation_cfg.repetition_penalty,
        "top_p": generation_cfg.top_p,
        "top_k": FIXED_RUNTIME_GENERATE_ARGS["top_k"],
        "min_p": FIXED_RUNTIME_GENERATE_ARGS["min_p"],
        "max_tokens": self.model_metadata.generation_max_tokens_for(self._get_prompt_token_count()),
        "skip_special_tokens": not keep_special_tokens,
        "include_stop_str_in_output": keep_special_tokens,
        "ignore_eos": False,
    }

    self.prepare_params(**sampling_kwargs)

    # Accumulates the batches and tracks the job's stopping condition.
    batches = GenerationBatches(
        target_num_records=generation_cfg.num_records,
        invalid_fraction_threshold=generation_cfg.invalid_fraction_threshold,
        patience=generation_cfg.patience,
        data_actions_fn=data_actions_fn,
    )

    # Statuses that end the job before the record target is met.
    early_stop_statuses = (
        GenerationStatus.STOP_NO_RECORDS,
        GenerationStatus.STOP_METRIC_REACHED,
    )

    with heartbeat(
        "Generation",
        logger_name=__name__,
        target_records=generation_cfg.num_records,
        progress_note=("Long stretches with no new records are normal."),
    ):
        while batches.num_valid_records < generation_cfg.num_records:
            # Issue one round of prompts and time how long it takes.
            prompt_count = batches.get_next_num_prompts()
            batch_started = time.perf_counter()
            batch: Batch = self._generate_batch(
                num_prompts_per_batch=prompt_count,
                batch=Batch(processor=self.processor),
                **sampling_kwargs,
            )
            batch_duration = time.perf_counter() - batch_started
            batches.add_batch(batch)

            # Report the batch summary, then timing and overall progress.
            batch.log_summary(detailed_errors=self.use_detailed_logs)
            self._log_batch_timing_and_progress(
                batch=batch,
                duration=batch_duration,
                num_records=generation_cfg.num_records,
                num_valid_records=batches.num_valid_records,
                batches=batches,
            )

            if batches.status in early_stop_statuses:
                break

    batches.job_complete()
    batches.log_status()

    # The record cap is passed only for ungrouped data on a job that
    # finished with COMPLETE status; otherwise no cap is given.
    max_num_records = None
    if self.config.data.group_training_examples_by is None and batches.status == GenerationStatus.COMPLETE:
        max_num_records = generation_cfg.num_records

    self.elapsed_time = time.monotonic() - started_at
    self.gen_results = GenerateJobResults.from_batches(
        batches=batches,
        columns=self.columns,
        max_num_records=max_num_records,
        elapsed_time=self.elapsed_time,
    )

    return self.gen_results