Qwen3-4B GRPO on GSM8K with slime on Modal (colocated)
What slime is. slime is an RL
post-training framework that pairs Megatron (training) with SGLang
(rollouts) and orchestrates both with Ray. modal-training-gym’s
slime launcher wires that stack onto a Modal multi-node cluster.
What this tutorial does. GRPO-tunes Qwen3-4B against
GSM8K on 4 nodes ×
8×H100 with actor and rollout colocated on the same GPUs. GSM8K
is the canonical target for math-RL: short prompts, short answers,
and a deterministic correctness check. This is the “everything works
end-to-end” reference for the slime framework — a medium-scale RL
post-training run with slime’s built-in math reward (no custom
reward code). For a custom-reward example see
slime_haiku; for the shared
primitives (DatasetConfig, volumes, the 3-stage pipeline) see
quickstart.
What you’ll need.
- Access to Modal’s multi-node training preview (4 × 8×H100).
- A `wandb` Modal secret holding your W&B API key (the slime launcher mounts it automatically when `WandbConfig` is present).
- Patience: this is a multi-hour run, so use `modal run --detach`.
What to watch.
- Weights & Biases under project `slime-grpo`, group `qwen3-4b-gsm8k`. Rollout reward should climb steadily; eval fires every 20 training steps (see `eval_interval` below) against GSM8K's test split.
- Modal dashboard: per-node GPU utilization and live logs. On a healthy run you'll see SGLang warm up, then alternating rollout / training phases.
```python
import modal

from modal_training_gym.common.dataset import DatasetConfig
from modal_training_gym.common.models import Qwen3_4B
from modal_training_gym.common.wandb import WandbConfig
from modal_training_gym.frameworks.slime import (
    ModalConfig,
    SlimeConfig,
)
from modal_training_gym.frameworks.slime.config import DATA_PATH
```

Define the dataset

The non-obvious choices for GSM8K under slime:

- `input_key="messages"` + `apply_chat_template=True`: the prompt column holds a list of chat messages; slime runs the model's chat template over them before tokenizing. The upstream `zhuzilin/gsm8k` mirror we load below is already in that shape.
- `label_key="label"`: the column slime scores against.
- `rm_type="math"`: selects slime's built-in math correctness reward. It parses the boxed numeric answer out of the rollout and compares it to `label`. No custom reward code needed.
- `rollout_shuffle=True`: matters for GRPO's group-sampling stability.
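To make the built-in math reward concrete: it boils down to extracting the final boxed answer from a rollout and matching it against `label`. A minimal sketch of that idea (not slime's actual implementation, which handles more answer formats):

```python
import re

def boxed_answer(text: str):
    # Take the last \boxed{...} occurrence; models often restate intermediate work.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def math_reward(rollout: str, label: str) -> float:
    # Binary correctness reward: 1.0 on exact match, else 0.0.
    pred = boxed_answer(rollout)
    return 1.0 if pred is not None and pred == label.strip() else 0.0

print(math_reward(r"18 - 3 - 4 = 11, so the answer is \boxed{11}", "11"))  # 1.0
print(math_reward(r"the answer is \boxed{12}", "11"))                      # 0.0
```

GSM8K's deterministic answers are what make a binary string-match reward like this viable; fuzzier tasks need a custom reward (see slime_haiku).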
```python
class GSM8KDataset(DatasetConfig):
    input_key = "messages"
    label_key = "label"
    apply_chat_template = True
    rollout_shuffle = True
    rm_type = "math"

    def __init__(self, data_path):
        self._data_path = str(data_path)
        self.prompt_data = f"{self._data_path}/gsm8k/train.parquet"
        self.eval_prompt_data = ["gsm8k", f"{self._data_path}/gsm8k/test.parquet"]

    def prepare(self):
        import os

        from datasets import load_dataset

        os.makedirs(f"{self._data_path}/gsm8k", exist_ok=True)
        ds = load_dataset("zhuzilin/gsm8k")
        ds["train"].to_parquet(f"{self._data_path}/gsm8k/train.parquet")
        ds["test"].to_parquet(f"{self._data_path}/gsm8k/test.parquet")
```

Define the experiment

`SlimeConfig` is a pydantic dataclass; hover over any field in
your IDE to see its type, default, and description. Only non-default
values need to be specified; everything else inherits sensible
defaults.
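The only-override-what-differs pattern, sketched with a stdlib dataclass and hypothetical field names (the real `SlimeConfig` fields and defaults live in slime's config):

```python
from dataclasses import dataclass

@dataclass
class ToyConfig:
    # Hypothetical fields standing in for SlimeConfig's real ones.
    actor_num_nodes: int = 1
    colocate: bool = True
    eval_interval: int = 20

cfg = ToyConfig(actor_num_nodes=4)  # everything else keeps its default
print(cfg.colocate, cfg.eval_interval)  # True 20
```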
Cluster
- `actor_num_nodes=4`: 32 H100s (4 × 8 GPUs).
- `colocate=True`: actor and rollout share the same GPUs.
Throughput
- `use_dynamic_batch_size=True` + `max_tokens_per_gpu=9216`: pack prompts up to a per-GPU token budget.
- `recompute_granularity="full"`: activation recomputation for memory savings.
RL
- `use_kl_loss=True`: KL divergence is computed (but `kl_loss_coef` defaults to 0.0, so it's tracked, not penalized).
- `weight_decay=0.1`, `adam_beta2=0.98`: optimizer overrides.
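The effect of a zero `kl_loss_coef` in schematic form, assuming the usual `loss = pg_loss + kl_loss_coef * kl` combination (an illustration, not slime's actual loss assembly):

```python
def total_loss(pg_loss: float, kl: float, kl_loss_coef: float = 0.0) -> float:
    # KL is still computed (and can be logged) even when its coefficient is 0.
    return pg_loss + kl_loss_coef * kl

kl = 0.37  # tracked in metrics either way
print(total_loss(pg_loss=1.25, kl=kl))                    # 1.25: KL not penalized
print(round(total_loss(pg_loss=1.25, kl=kl, kl_loss_coef=0.1), 3))  # 1.287: KL penalized
```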
```python
base_model = Qwen3_4B()

my_training_run = SlimeConfig(
    model=base_model,
    dataset=GSM8KDataset(DATA_PATH),
    wandb=WandbConfig(project="slime-grpo", group="qwen3-4b-gsm8k"),
    ref_load=base_model.model_name,
    actor_num_nodes=4,
    modal=ModalConfig(gpu="H100"),
)
```

Build and run

`build_app()` returns a Modal app with `download_model`,
prepare_dataset, and train. (Bridge mode means there’s no
separate convert_checkpoint step to call — see quickstart for the
general pattern.)
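A detached launch from the CLI then looks roughly like this (illustrative; the file path matches the Source link at the end, and on a fresh volume you'd invoke the download and dataset-preparation stages first, per quickstart):

```shell
# Illustrative invocation; stage entrypoints follow the quickstart pattern.
modal run --detach tutorials/rl/slime_gsm8k/slime_gsm8k.py
```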
```python
app = my_training_run.build_app()
```

Related API Reference

Source: tutorials/rl/slime_gsm8k/slime_gsm8k.py