Algorithm Baselines

GSM8k

The results below assume the GSM8k dataset was preprocessed with python3 examples/data_preprocess/gsm8k.py.
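For reference, each GSM8k reference solution ends with a line of the form "#### <final answer>", and that final number is typically used as the ground truth when scoring a model's response. The sketch below shows one way to pull it out. The helper extract_gsm8k_answer is hypothetical. It illustrates the answer format and is not the actual code in examples/data_preprocess/gsm8k.py.

```python
import re
from typing import Optional

# Assumption: the final answer follows "####" and may contain a sign,
# digits, a decimal point, or thousands separators (e.g. "#### 1,200").
_FINAL_ANSWER = re.compile(r"####\s*(-?[0-9.,]+)")


def extract_gsm8k_answer(solution: str) -> Optional[str]:
    """Hypothetical helper: return the final answer with commas removed, or None."""
    match = _FINAL_ANSWER.search(solution)
    if match is None:
        return None
    return match.group(1).replace(",", "")


print(extract_gsm8k_answer("She has 3 + 4 = 7 apples.\n#### 7"))
print(extract_gsm8k_answer("Total cost: 12 * 100 = 1200\n#### 1,200"))
```

Stripping thousands separators makes "1,200" and "1200" compare equal. A real grader may apply additional normalization.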

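The tables below list GSM8k test scores for different pre-trained models, both as released and after SFT or RL training (PPO, GRPO, ReMax, PRIME). The Details column points to the commands, scripts, and logs needed to reproduce each result.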

NVIDIA GPUs

| Model | Method | Test score | Details |
|---|---|---|---|
| google/gemma-2-2b-it | pretrained checkpoint | 23.9 | Huggingface |
| google/gemma-2-2b-it | SFT | 52.06 | SFT Command and Logs |
| google/gemma-2-2b-it | SFT + PPO | 64.02 | SFT+PPO Command and Logs, wandb |
| Qwen/Qwen2.5-0.5B-Instruct | pretrained checkpoint | 36.4 | Qwen Blog |
| Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | PPO Command and Logs |
| Qwen/Qwen2.5-0.5B-Instruct | PRIME | 58.7 | Qwen0.5b PRIME Script, Qwen0.5b PRIME Wandb |
| deepseek-ai/deepseek-llm-7b-chat | PPO | 69.5 [1] | Megatron PPO Command and Logs, Megatron wandb |
| Qwen/Qwen2-7B-Instruct | GRPO | 89 | Qwen7b GRPO Script |
| Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | Qwen7b ReMax Script, Qwen7b ReMax Wandb |
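For the two models that have a pretrained-checkpoint row, the improvement from training is the difference between each trained score and the pretrained score. The sketch below computes these gains in score points from the values in the NVIDIA table. The names BASELINES, RESULTS, and gains_over_pretrained are hypothetical and exist only for this illustration. The differences are derived here and are not reported in the original table.

```python
# Scores copied from the NVIDIA GPU table (GSM8k test score).
BASELINES = {
    "google/gemma-2-2b-it": 23.9,
    "Qwen/Qwen2.5-0.5B-Instruct": 36.4,
}

RESULTS = [
    ("google/gemma-2-2b-it", "SFT", 52.06),
    ("google/gemma-2-2b-it", "SFT + PPO", 64.02),
    ("Qwen/Qwen2.5-0.5B-Instruct", "PPO", 56.7),
    ("Qwen/Qwen2.5-0.5B-Instruct", "PRIME", 58.7),
]


def gains_over_pretrained(results, baselines):
    """Hypothetical helper: (model, method, absolute gain in points), rounded to 2 decimals."""
    return [
        (model, method, round(score - baselines[model], 2))
        for model, method, score in results
        if model in baselines
    ]


for model, method, gain in gains_over_pretrained(RESULTS, BASELINES):
    print(f"{model:30s} {method:10s} +{gain:.2f}")
```

Rounding hides floating-point noise such as 20.300000000000004, so the printed gains match hand subtraction of the table values.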

AMD GPUs (MI300)

| Model | Method | Test score | Details |
|---|---|---|---|
| deepseek-ai/deepseek-llm-7b-chat | PPO | 70.5 [1] | ppo_run_deepseek7b_llm.sh |
| deepseek-ai/deepseek-llm-7b-chat | GRPO | 71.4 [1] | grpo_run_deepseek7b_llm.sh |