Skip to content
GitHub

Multi-turn RL: guess a number from 1 to 20

Multi-turn number-guessing RL with custom generate and reward functions

This tutorial builds a small multi-turn environment inspired by PipelineRL’s guessing task. The model must guess a hidden number in [1, 20].

Loop per rollout:

  1. Model emits a guess as <answer>N</answer>.
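  2. Environment returns <feedback>higher</feedback> or <feedback>lower</feedback> for a wrong guess; a correct guess ends the episode.
  3. Repeat until the guess is correct, the output is malformed, or the turn budget (6 turns) runs out.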

We train with:

  • custom_generate_function: runs the interaction loop.
  • custom_rm_function: rewards correct answers with an early-turn bonus.
  • loss_mask: trains only on model-generated tokens, not environment feedback.
import json
import re
from modal_training_gym import (
DatasetConfig,
DeploymentConfig,
EvalConfig,
EvalRowResult,
ModelDeployment,
Qwen3_4B,
SlimeRecipe,
TrainConfig,
list_checkpoints,
)

Keep it simple:

  • train on odd targets
  • evaluate on even targets
# Inclusive upper bound of the hidden number; guesses outside 1.._MAX_VALUE are
# rejected as malformed by the answer parser.
_MAX_VALUE = 20
# Default number of guesses allowed per episode.
_MAX_TURNS = 6
# Instructions shown to the model. The range is interpolated from _MAX_VALUE so
# the prompt cannot drift from the bound the environment actually enforces.
_PROMPT = (
    "You are playing a number guessing game.\n"
    f"The hidden integer is between 1 and {_MAX_VALUE}.\n"
    "Return only guesses in this exact format: <answer>N</answer> "
    f"where N is an integer between 1 and {_MAX_VALUE}.\n"
    "After each guess, you will receive <feedback>higher</feedback> or "
    "<feedback>lower</feedback>, and must update your next guess accordingly."
)
# Odd targets are used for training; even targets are held out for evaluation.
TRAIN_TARGETS = list(range(1, _MAX_VALUE + 1, 2))
TEST_TARGETS = list(range(2, _MAX_VALUE + 1, 2))
class NumberGuessDataset(DatasetConfig):
    """Number-guessing dataset: odd targets for training, even targets for eval.

    Every row carries the same prompt; only the hidden target differs. The
    class attributes below are configuration consumed by the ``DatasetConfig``
    machinery (NOTE(review): their exact semantics live in
    ``modal_training_gym`` — confirm there).
    """

    input_key = "messages"
    label_key = "label"
    apply_chat_template = True
    input_column = "prompt"
    # For the purpose of this tutorial, we want to prepare the dataset every
    # time we run it, in case there is stale data from a previous run.
    always_prepare = True

    def load(self, split="all"):
        """Return one ``{"prompt", "target"}`` row per target for *split*.

        Only ``"train"`` selects the odd training targets; every other value,
        including the default ``"all"``, yields the held-out even targets.
        """
        targets = TRAIN_TARGETS if split == "train" else TEST_TARGETS
        return [{"prompt": _PROMPT, "target": target} for target in targets]

    def prepare(self, path: str, eval_paths: dict[str, str] | None = None):
        """Write the training parquet to *path* and the eval parquet to each eval path.

        Args:
            path: Destination of the training parquet file. Each training
                target is repeated 20 times.
            eval_paths: Optional mapping whose values are destinations for the
                eval parquet file (one row per test target). Keys are ignored.
        """
        from pathlib import Path

        from datasets import Dataset

        def _row(target: int) -> dict:
            return {
                "messages": [{"role": "user", "content": _PROMPT}],
                "label": json.dumps({"answer": target}),
            }

        def _write(rows: list[dict], destination: str) -> None:
            # Path.parent of a bare filename is ".", so this also works when
            # the destination has no directory component (os.makedirs("")
            # would raise FileNotFoundError).
            Path(destination).parent.mkdir(parents=True, exist_ok=True)
            Dataset.from_list(rows).to_parquet(destination)

        train_rows = [_row(target) for target in TRAIN_TARGETS for _ in range(20)]
        _write(train_rows, path)
        if eval_paths:
            eval_rows = [_row(target) for target in TEST_TARGETS]
            for eval_path in eval_paths.values():
                _write(eval_rows, eval_path)
# Two separate instances of the same dataset class: one is handed to
# TrainConfig, the other to EvalConfig.
train_dataset = NumberGuessDataset()
eval_dataset = NumberGuessDataset()

number_guess_generate is the environment loop:

  • model generates <answer>N</answer>
  • environment appends <feedback>higher|lower</feedback> when incorrect
  • only model text is trained (loss_mask=1), feedback is masked out (loss_mask=0)

Reward mirrors PipelineRL-style shaping:

  • success: 2.0 - 0.1 * (turns - 1)
  • malformed output (no <answer>N</answer> tag, or N outside 1–20): -2.0
  • otherwise (no correct guess before the episode ends): -1.0
# Matches a guess such as "<answer> 7 </answer>": optional whitespace around
# the digits, tag names compared case-insensitively. Group 1 is the digit run.
_ANSWER_RE = re.compile(
    r"<answer>\s*(\d+)\s*</answer>",
    flags=re.IGNORECASE,
)
def _parse_label(sample) -> dict:
    """Return the label payload attached to *sample* as a dict.

    ``sample.label`` may already be a dict or may be a JSON string. A missing
    label, unparseable JSON, or JSON that does not decode to an object all
    yield an empty dict, so callers can always use ``.get`` on the result.
    """
    raw = getattr(sample, "label", None)
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    # Valid JSON is not necessarily an object ("5", "[1, 2]", "null"); only a
    # dict honours the declared return type.
    return raw if isinstance(raw, dict) else {}
def _extract_answer(text: str) -> int | None:
    """Return the last in-range guess found in *text*, or ``None``.

    Only the final ``<answer>N</answer>`` occurrence is considered. ``None``
    is returned when there is no such tag or when N lies outside
    ``1.._MAX_VALUE``.
    """
    # The pattern has a single capture group, so findall yields digit strings.
    candidates = _ANSWER_RE.findall(text)
    if not candidates:
        return None
    value = int(candidates[-1])
    return value if 1 <= value <= _MAX_VALUE else None
async def number_guess_generate(args, sample, sampling_params):
    """Run one multi-turn guessing episode and fill *sample* with the result.

    Each turn posts the prompt plus the trajectory so far to the sglang router,
    parses the model's guess, and appends ``<feedback>higher|lower</feedback>``
    when the guess is wrong. The episode stops on a correct guess, a malformed
    guess, a length-truncated generation, or after ``max_turns`` turns.

    Args:
        args: Run configuration; reads ``sglang_router_ip``,
            ``sglang_router_port`` and the optional ``max_turns``.
        sample: Sample to roll out; reads ``prompt`` and ``label``.
        sampling_params: Passed through unchanged to the ``/generate`` call.

    Returns:
        The same *sample*. If generation was aborted, only ``status`` is set.
        Otherwise ``tokens``, ``response_length``, ``response``, ``loss_mask``,
        ``status`` and ``metadata["guessing"]`` are set, with loss mask 1 on
        model text and 0 on environment feedback.
    """
    from slime.rollout.sglang_rollout import GenerateState
    from slime.utils.http_utils import post
    from slime.utils.types import Sample

    label = _parse_label(sample)
    target = int(label.get("answer", 1))
    max_turns = int(getattr(args, "max_turns", _MAX_TURNS))
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    prompt_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]

    trajectory_text = ""
    # (text, trainable) pairs in generation order; trainable is the loss-mask value.
    response_segments: list[tuple[str, int]] = []
    success = False
    format_error = False
    turns_taken = max_turns
    final_status = Sample.Status.COMPLETED

    for turn in range(max_turns):
        output = await post(
            url,
            {
                # Send exactly prompt + trajectory, with no extra separator or
                # stripping, so the context the model samples from matches the
                # text that is tokenized into sample.tokens below.
                "text": sample.prompt + trajectory_text,
                "sampling_params": sampling_params,
            },
        )
        finish_type = output["meta_info"]["finish_reason"]["type"]
        if finish_type == "abort":
            sample.status = Sample.Status.ABORTED
            return sample

        model_text = output["text"]
        trajectory_text += model_text
        response_segments.append((model_text, 1))

        guess = _extract_answer(model_text)
        if guess is None:
            format_error = True
            turns_taken = turn + 1
            break
        if guess == target:
            success = True
            turns_taken = turn + 1
            break

        feedback = "higher" if guess < target else "lower"
        feedback_text = f"\n<feedback>{feedback}</feedback>\n"
        trajectory_text += feedback_text
        response_segments.append((feedback_text, 0))

        if finish_type == "length":
            final_status = Sample.Status.TRUNCATED
            break

    # Tokenize segment by segment so each token inherits its segment's mask.
    response_token_ids: list[int] = []
    loss_masks: list[int] = []
    for segment_text, trainable in response_segments:
        token_ids = state.tokenizer(
            segment_text,
            add_special_tokens=False,
        )["input_ids"]
        response_token_ids += token_ids
        loss_masks += [trainable] * len(token_ids)

    sample.tokens = prompt_ids + response_token_ids
    sample.response_length = len(response_token_ids)
    sample.response = trajectory_text
    sample.loss_mask = loss_masks
    sample.status = final_status

    # Preserve existing metadata when it is a dict; otherwise start fresh.
    sample_metadata = getattr(sample, "metadata", None)
    if not isinstance(sample_metadata, dict):
        sample_metadata = {}
    sample_metadata["guessing"] = {
        "target": target,
        "success": success,
        "format_error": format_error,
        "turns_taken": turns_taken,
    }
    sample.metadata = sample_metadata
    return sample
def _trajectory_reward(success: bool, format_error: bool, turns_taken: int) -> float:
    """Score one finished episode.

    A success earns ``2.0`` minus ``0.1`` for every turn after the first.
    A failed episode earns ``-2.0`` if it ended on a malformed guess and
    ``-1.0`` otherwise. ``success`` takes precedence over ``format_error``.
    """
    if not success:
        return -2.0 if format_error else -1.0
    extra_turns = max(0, turns_taken - 1)
    return float(2.0 - 0.1 * extra_turns)
async def number_guess_rm(args, sample, **kwargs) -> float:
    """Reward function: score the episode recorded in ``sample.metadata["guessing"]``.

    Missing metadata or missing fields fall back to ``success=False``,
    ``format_error=False`` and ``turns_taken=args.max_turns`` (or
    ``_MAX_TURNS`` when *args* has no ``max_turns``).
    """
    metadata = getattr(sample, "metadata", None)
    if isinstance(metadata, dict):
        episode = metadata.get("guessing", {})
    else:
        episode = {}
    return _trajectory_reward(
        success=bool(episode.get("success", False)),
        format_error=bool(episode.get("format_error", False)),
        turns_taken=int(
            episode.get("turns_taken", getattr(args, "max_turns", _MAX_TURNS))
        ),
    )

EvalConfig supports eval_fn, so we can plug in a full multi-turn evaluator per row while still using the standard eval runner.

def run_guessing_trajectory(
    deployment: ModelDeployment,
    *,
    target: int,
    max_turns: int = _MAX_TURNS,
) -> dict:
    """Play one guessing episode against *deployment* and report the outcome.

    Each turn sends the game prompt plus the transcript so far, with thinking
    disabled. The episode ends on a correct guess, on a malformed guess, or
    after *max_turns* turns.

    Returns:
        A dict with ``success``, ``format_error``, ``turns_taken`` and
        ``response``. ``response`` is the final model reply when the episode
        ends early, and the accumulated transcript when turns run out.
    """

    def _outcome(success: bool, format_error: bool, turns_taken: int, response: str) -> dict:
        return {
            "success": success,
            "format_error": format_error,
            "turns_taken": turns_taken,
            "response": response,
        }

    transcript = ""
    for attempt in range(1, max_turns + 1):
        reply = deployment.generate(
            f"{_PROMPT}\n{transcript}".strip(),
            chat_template_kwargs={"enable_thinking": False},
        )
        guess = _extract_answer(reply)
        if guess is None:
            return _outcome(False, True, attempt, reply)
        if guess == target:
            return _outcome(True, False, attempt, reply)
        hint = "higher" if guess < target else "lower"
        transcript += f"{reply}\n<feedback>{hint}</feedback>\n"
    return _outcome(False, False, max_turns, transcript)
def _resolve_target(example: dict) -> int:
    """Return the hidden target number for an eval *example*.

    An explicit ``"target"`` key wins. Otherwise the ``"label"`` value is
    used: either a dict or a JSON string encoding an object, from which
    ``"answer"`` is read. Anything else (missing label, unparseable JSON,
    JSON that is not an object, missing ``"answer"``) falls back to ``1``.
    """
    if "target" in example:
        return int(example["target"])
    payload = example.get("label")
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            return 1
    # json.loads can return a list, number or None; only a dict supports .get.
    if isinstance(payload, dict):
        return int(payload.get("answer", 1))
    return 1
def guessing_eval_fn(
    deployment: ModelDeployment,
    example: dict,
) -> EvalRowResult:
    """Evaluate one row: play a full episode for its target and score it.

    The score uses the same reward shaping as training. The metadata records
    ``success``, ``format_error``, ``turns_taken`` and the ``target``.
    """
    target = _resolve_target(example)
    outcome = run_guessing_trajectory(
        deployment,
        target=target,
        max_turns=_MAX_TURNS,
    )
    # The three fields below feed both the reward and the row metadata.
    verdict = {
        field: outcome[field]
        for field in ("success", "format_error", "turns_taken")
    }
    return EvalRowResult(
        score=_trajectory_reward(**verdict),
        response=outcome["response"],
        metadata={**verdict, "target": target},
    )
def summarize_eval(eval_result) -> dict:
    """Aggregate per-row eval metadata into a success rate and mean turn count.

    Rows without ``turns_taken`` count as ``_MAX_TURNS`` turns. An empty
    result yields ``0.0`` for both metrics rather than dividing by zero.
    """
    rows = eval_result.rows
    denominator = max(len(rows), 1)
    solved = [row for row in rows if row.metadata.get("success")]
    turn_counts = [int(row.metadata.get("turns_taken", _MAX_TURNS)) for row in rows]
    return {
        "success_rate": float(len(solved) / denominator),
        "mean_turns": float(sum(turn_counts) / denominator),
    }
# Serve the untrained base model and evaluate it first, so there is a baseline
# to compare the trained checkpoint against at the end of the script.
base_deployment = DeploymentConfig(model=Qwen3_4B()).serve()
print(f"Base model URL: {base_deployment.url}")
# eval_fn replaces the default per-row scoring with the full multi-turn episode.
eval_config = EvalConfig(
    dataset=eval_dataset,
    eval_fn=guessing_eval_fn,
)
base_eval = eval_config.evaluate(base_deployment, debug=True)
base_summary = summarize_eval(base_eval)
print(f"Base success rate: {base_summary['success_rate']:.2%}")
print(f"Base mean reward: {base_eval.mean:.3f}")
print(f"Base mean turns: {base_summary['mean_turns']:.2f}")

A quick tour of the SlimeRecipe knobs we set below.

Cluster and parallelism

  • gpu_type="H100" — GPU SKU used for both the rollout (sglang) and training (Megatron) ranks.
  • colocate=True — share the same GPUs between rollout and training, alternating between the two. Set False to give sglang dedicated GPUs (faster, more expensive).
  • tensor_model_parallel_size=1 — Megatron tensor-parallel degree. 1 keeps the 4B model on a single GPU; bump it for larger models that don’t fit.
  • sequence_parallel=False — only meaningful when tensor_model_parallel_size > 1.
  • rollout_num_gpus_per_engine=1 — GPUs per sglang inference engine (sglang’s TP).

Rollout

  • num_rollout=20 — total rollout/train iterations to run. Each iteration samples a batch, scores it, and applies one policy update.
  • rollout_batch_size=8 — prompts sampled per rollout iteration.
  • n_samples_per_prompt=1 — GRPO group size. 1 disables grouping; bump to ≥2 to get within-prompt advantage normalization.
  • rollout_max_response_len=64 — max new tokens per sglang call. We keep it tiny because each turn only needs to emit <answer>N</answer>; thinking is disabled via apply_chat_template_kwargs (below), so the rest is headroom.
  • rollout_temperature=1.0 — sampling temperature during rollouts.

Training and checkpoints

  • global_batch_size=8 — effective batch size for the policy gradient update.
  • save_interval=10 — write a Megatron checkpoint every N rollout iterations.
  • apply_chat_template_kwargs='{"enable_thinking": false}' — passed to the tokenizer’s chat template; disables Qwen3’s <think> block so responses stay short and parseable.
# Training configuration: Qwen3-4B on the odd-target dataset, using the custom
# multi-turn generate and reward functions defined above.
training_run = TrainConfig(
    model=Qwen3_4B(),
    dataset=train_dataset,
    recipe=SlimeRecipe(
        custom_generate_function=number_guess_generate,
        custom_rm_function=number_guess_rm,
        # number_guess_generate and number_guess_rm read max_turns from args.
        # NOTE(review): presumably extra_config entries surface as attributes
        # on args — confirm.
        extra_config={
            "max_turns": _MAX_TURNS,
            "log_multi_turn": True,
        },
        # Cluster and parallelism.
        gpu_type="H100",
        colocate=True,  # rollout and training share the same GPUs
        tensor_model_parallel_size=1,
        sequence_parallel=False,  # only meaningful when tensor parallel > 1
        rollout_num_gpus_per_engine=1,
        # Rollout.
        num_rollout=20,  # total rollout/train iterations
        rollout_batch_size=8,  # prompts sampled per rollout iteration
        n_samples_per_prompt=1,  # GRPO group size; 1 disables grouping
        rollout_max_response_len=64,  # max new tokens per sglang call
        rollout_temperature=1.0,
        # Training and checkpoints.
        global_batch_size=8,
        save_interval=10,  # checkpoint every N rollout iterations
        # Disables Qwen3's <think> block so responses stay short and parseable.
        apply_chat_template_kwargs='{"enable_thinking": false}',
    ),
)
# Train, serve the newest checkpoint, and re-run the same evaluation so the
# trained and base numbers can be compared side by side.
print("Starting training...")
train_result = training_run.train()
print(f"Training run id: {train_result.training_run_id}")

# Fail with an actionable message instead of a bare IndexError when the run
# wrote no checkpoint (e.g. it stopped before the first save_interval).
checkpoints = list_checkpoints(train_result.training_run_id)
if not checkpoints:
    raise RuntimeError(
        f"Training run {train_result.training_run_id} produced no checkpoints; "
        "check save_interval and the training logs."
    )
# NOTE(review): assumes list_checkpoints returns oldest-to-newest, so the last
# entry is the most recent checkpoint — confirm.
checkpoint = checkpoints[-1]

trained_deployment = DeploymentConfig(
    model=Qwen3_4B(),
    checkpoint=checkpoint,
    app_name="qwen3-4b-guessing-multiturn-serve",
    served_model_name="qwen3-4b-guessing-multiturn",
).serve()
print(f"Trained model URL: {trained_deployment.url}")
trained_eval = eval_config.evaluate(trained_deployment, debug=True)
trained_summary = summarize_eval(trained_eval)
print(f"Trained success rate: {trained_summary['success_rate']:.2%}")
print(f"Trained mean reward: {trained_eval.mean:.3f}")
print(f"Trained mean turns: {trained_summary['mean_turns']:.2f}")
# Repeat the baseline numbers next to the trained ones for easy comparison.
print(f"Base success rate: {base_summary['success_rate']:.2%}")
print(f"Base mean reward: {base_eval.mean:.3f}")
print(f"Base mean turns: {base_summary['mean_turns']:.2f}")

Source: tutorials/rl/002_multiturn/002_multiturn.py | Open in Modal Notebook