
Model training

We assume you have /workspace defined in your cluster config and that data and models will be downloaded to that folder.

Download data and convert to SFT format

Get the data from HuggingFace and convert it to the SFT JSONL format expected by the NeMo-RL SFT pipeline. This might take a while (depending on your network connection) and will use a significant amount of RAM.

from functools import partial
from datasets import load_dataset
from nemo_skills.prompt.utils import get_prompt

def apply_format(elem, prompt, is_tir):
    if is_tir:
        # Recover the code execution budget from the solution text. If the solution
        # never shows a "Remaining code executions: N" message, it must have
        # exhausted its budget right away, i.e. the budget was 1.
        if 'Remaining code executions: ' not in elem['output']:
            assert 'You have run out of code executions!' in elem['output']
            total_code_executions = 1
        else:
            # the first "Remaining code executions: N" message appears after one
            # execution has already been used, so the budget is N + 1
            total_code_executions = int(elem['output'].split('Remaining code executions: ')[1].split()[0][0]) + 1
        elem['input'] = prompt.fill({'problem': elem['input'], 'total_code_executions': total_code_executions}, format_as_string=True)
    else:
        elem['input'] = prompt.fill({'problem': elem['input']}, format_as_string=True)
    elem['output'] = prompt.format_assistant_response(elem['output'])
    return elem

dataset = load_dataset("nvidia/OpenMathReasoning")

for inference_mode in ["cot", "tir", "genselect"]:
    dataset[inference_mode] = dataset[inference_mode].rename_column("problem", "input")
    dataset[inference_mode] = dataset[inference_mode].rename_column("generated_solution", "output")

    code_tags = None
    if inference_mode == 'cot':
        prompt_config = 'generic/math'
    if inference_mode == 'tir':
        prompt_config = 'openmath/tir'
        code_tags = 'openmath'
    if inference_mode == 'genselect':  # already formatted
        prompt_config = {'user': '{problem}'}
    prompt = get_prompt(prompt_config, tokenizer='Qwen/Qwen2.5-32B-Instruct', code_tags=code_tags, system_message="")
    func = partial(apply_format, prompt=prompt, is_tir=(inference_mode == 'tir'))
    dataset[inference_mode] = dataset[inference_mode].map(func, num_proc=20)

dataset["cot"].to_json("omr-cot.jsonl")
dataset["tir"].to_json("omr-tir.jsonl")
dataset["genselect"].to_json("omr-genselect.jsonl")

If you want to train on all the data, mix it together by running the following commands:

cat omr-cot.jsonl omr-tir.jsonl omr-genselect.jsonl > omr-all.jsonl
shuf -o omr-all.jsonl omr-all.jsonl
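
To sanity-check the converted data, you can inspect the first record of the combined file (a minimal sketch; the input and output fields are produced by the conversion above, and the remaining dataset columns are carried over as-is):

import json

with open("omr-all.jsonl") as f:
    record = json.loads(f.readline())

# 'input' holds the fully formatted prompt string and 'output' the assistant response
print(sorted(record.keys()))
print(record["input"][:500])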

Prepare base model

Download the base model. We used the following base models: Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-14B, and Qwen/Qwen2.5-32B.

Here is an example of the download commands for Qwen2.5-Math-1.5B:

pip install -U "huggingface_hub[cli]"
hf download Qwen/Qwen2.5-Math-1.5B --local-dir Qwen2.5-Math-1.5B

For the 1.5B and 7B models we use the "Math" variants, so we also need to update their RoPE base frequency and maximum position embeddings. For 14B and 32B you should not do that!

sed -i 's/"max_position_embeddings": 4096,/"max_position_embeddings": 131072,/g' Qwen2.5-Math-1.5B/config.json
sed -i 's/"rope_theta": 10000,/"rope_theta": 500000.0,/g' Qwen2.5-Math-1.5B/config.json

Run training

Run the training (assuming a Slurm configuration here with the same folder structure). If your cluster has a strict timeout policy, you can run multiple dependent jobs with --dependent_jobs=N.

ns nemo_rl sft \
    --cluster=slurm \
    --expname=openmathreasoning-repro-1.5b \
    --output_dir=/workspace/openmathreasoning-sft/checkpoints \
    --hf_model=/workspace/Qwen2.5-Math-1.5B  \
    --num_nodes=64 \
    --num_gpus=8 \
    --backend=megatron \
    --average_steps=7500,15000,22500,30000 \
    --training_data=/workspace/openmathreasoning-sft/omr-all.jsonl \
    ++policy.max_total_sequence_length=32768 \
    ++policy.train_micro_batch_size=1 \
    ++policy.train_global_batch_size=1024 \
    ++policy.tensor_model_parallel_size=1 \
    ++policy.context_parallel_size=2 \
    ++policy.lr=3e-4 \
    ++policy.min_lr=3e-7 \
    ++policy.megatron_cfg.scheduler.lr_warmup_iters=3000 \
    ++policy.megatron_cfg.scheduler.lr_warmup_init=0 \
    ++checkpointing.save_period=7500 \
    ++sft.max_num_steps=30000 \
    ++sft.max_num_epochs=100
The command above uses the hyperparameter values for Qwen2.5-Math-1.5B. For the other models, adjust the learning rate (lr), minimum learning rate (min_lr), tensor parallel size (TP), and context parallel size (CP) as follows:

Model               lr    min_lr  TP  CP
Qwen2.5-Math-1.5B   3e-4  3e-7    1   2
Qwen2.5-Math-7B     2e-4  2e-7    4   2
Qwen2.5-14B         1e-4  1e-7    8   2
Qwen2.5-32B         1e-4  1e-7    8   4
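
For example, to train Qwen2.5-32B you would keep the command above but change the following flags according to the table:

    --hf_model=/workspace/Qwen2.5-32B \
    ++policy.lr=1e-4 \
    ++policy.min_lr=1e-7 \
    ++policy.tensor_model_parallel_size=8 \
    ++policy.context_parallel_size=4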

If you want to follow up with checkpoint conversion and evaluation, see the training docs for an example of how to do it through a convenient Python API.

Second-round SFT

Note

After release we realized that we didn't apply filtering to the TIR and GenSelect subsets. If you want to reproduce our results exactly, modify the code below to only apply filtering to the CoT subset and use the original TIR and GenSelect subsets. In this case, also change the training duration to 10000 steps and update the average steps and warmup accordingly (e.g., --average_steps=2500,5000,7500,10000 and ++policy.megatron_cfg.scheduler.lr_warmup_iters=1000, following the same proportions as the commands below).

For best results, though, we recommend filtering all subsets. To do that, run the commands below without changes.
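
For the exact reproduction described above, you would make the filter call in the script below conditional on the subset, e.g.:

if inference_mode == 'cot':  # only the CoT subset was filtered in the released models
    func = partial(filter_func, inference_mode=inference_mode)
    dataset[inference_mode] = dataset[inference_mode].filter(func, num_proc=20)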

In our paper we also did a second round of SFT for all models except 32B. All the commands stay the same, except for the following changes to the initial data preparation, as well as training for 3000 steps instead of the 30000 used in the first-round SFT.

    --hf_model=/workspace/openmathreasoning-sft/final_hf_model \
    --training_data=<path to the new data> \
    --average_steps=750,1500,2250,3000 \
    ++policy.megatron_cfg.scheduler.lr_warmup_iters=300 \
    ++policy.megatron_cfg.scheduler.lr_warmup_init=0 \
    ++sft.max_num_steps=3000

Here is the code that can be used to prepare the second-round SFT data:

from functools import partial
from datasets import load_dataset
from transformers import AutoTokenizer
from nemo_skills.prompt.utils import get_prompt

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B")

def apply_format(elem, prompt, is_tir):
    if is_tir:
        # Recover the code execution budget from the solution text. If the solution
        # never shows a "Remaining code executions: N" message, it must have
        # exhausted its budget right away, i.e. the budget was 1.
        if 'Remaining code executions: ' not in elem['output']:
            assert 'You have run out of code executions!' in elem['output']
            total_code_executions = 1
        else:
            # the first "Remaining code executions: N" message appears after one
            # execution has already been used, so the budget is N + 1
            total_code_executions = int(elem['output'].split('Remaining code executions: ')[1].split()[0][0]) + 1
        elem['input'] = prompt.fill({'problem': elem['input'], 'total_code_executions': total_code_executions}, format_as_string=True)
    else:
        elem['input'] = prompt.fill({'problem': elem['input']}, format_as_string=True)
    elem['output'] = prompt.format_assistant_response(elem['output'])
    return elem

def filter_func(example, inference_mode):
    # keep only problems from the two olympiad-level AoPS sources
    olymp_sources = ['aops_c5_contests_amp_programs', 'aops_c6_high_school_olympiads']
    if example['problem_source'] not in olymp_sources:
        return False
    # keep only hard problems: those the 72B TIR model solves at most 30% of the time
    if example['pass_rate_72b_tir'] == 'n/a' or float(example['pass_rate_72b_tir']) > 0.3:
        return False
    if inference_mode == 'genselect':  # no length-based filtering for genselect
        return True
    # keep only long solutions (at least 5000 tokens)
    return len(tokenizer.encode(example['output'])) >= 5000

dataset = load_dataset("nvidia/OpenMathReasoning")

for inference_mode in ["cot", "tir", "genselect"]:
    dataset[inference_mode] = dataset[inference_mode].rename_column("problem", "input")
    dataset[inference_mode] = dataset[inference_mode].rename_column("generated_solution", "output")

    code_tags = None
    if inference_mode == 'cot':
        prompt_config = 'generic/math'
    if inference_mode == 'tir':
        prompt_config = 'openmath/tir'
        code_tags = 'openmath'
    if inference_mode == 'genselect':  # already formatted
        prompt_config = {'user': '{problem}'}
    func = partial(filter_func, inference_mode=inference_mode)
    dataset[inference_mode] = dataset[inference_mode].filter(func, num_proc=20)
    prompt = get_prompt(prompt_config, tokenizer='Qwen/Qwen2.5-32B-Instruct', code_tags=code_tags, system_message="")
    func = partial(apply_format, prompt=prompt, is_tir=(inference_mode == 'tir'))
    dataset[inference_mode] = dataset[inference_mode].map(func, num_proc=20)

dataset["cot"].to_json("omr-cot-round2.jsonl")
dataset["tir"].to_json("omr-tir-round2.jsonl")
dataset["genselect"].to_json("omr-genselect-round2.jsonl")

Since the data is relatively small, you don't need to split it and can pack the full file directly.