Algorithm Baselines
GSM8k
Assuming the GSM8k dataset has been preprocessed via `python3 examples/data_preprocess/gsm8k.py`.
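For orientation, GSM8k reference solutions end with a line of the form `#### <answer>`. The sketch below shows one way to pull out that final answer as the ground truth for each problem. It is a minimal illustration under that assumption, not the contents of `examples/data_preprocess/gsm8k.py`; the helper names `extract_final_answer` and `build_record` and the record fields are hypothetical.

```python
import re
from typing import Optional

# Assumption: GSM8k solutions end with "#### <answer>", where the answer may
# contain a minus sign, digits, a decimal point, or thousands separators.
_ANSWER_RE = re.compile(r"####\s*(-?[0-9.,]+)")


def extract_final_answer(solution: str) -> Optional[str]:
    """Return the final answer after '####' with commas removed, or None."""
    match = _ANSWER_RE.search(solution)
    if match is None:
        return None
    return match.group(1).replace(",", "")


def build_record(question: str, solution: str) -> dict:
    # Hypothetical record layout: the prompt text plus the extracted ground truth.
    return {"prompt": question, "ground_truth": extract_final_answer(solution)}


if __name__ == "__main__":
    sample = "16 - 3 - 4 = <<16-3-4=9>>9 eggs.\n9 * 2 = <<9*2=18>>18\n#### 18"
    print(build_record("How much does she make?", sample))
```

Running it prints `{'prompt': 'How much does she make?', 'ground_truth': '18'}`.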
Refer to the tables below to reproduce PPO and other RL training results (as well as SFT baselines) from different pre-trained models.
NVIDIA GPUs
| Model | Method | Test score | Details |
|---|---|---|---|
| google/gemma-2-2b-it | pretrained checkpoint | 23.9 | |
| google/gemma-2-2b-it | SFT | 52.06 | |
| google/gemma-2-2b-it | SFT + PPO | 64.02 | |
| Qwen/Qwen2.5-0.5B-Instruct | pretrained checkpoint | 36.4 | |
| Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | |
| Qwen/Qwen2.5-0.5B-Instruct | PRIME | 58.7 | |
| deepseek-ai/deepseek-llm-7b-chat | PPO | 69.5 [1] | |
| Qwen/Qwen2-7B-Instruct | GRPO | 89 | |
| Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | |
AMD GPUs (MI300)
| Model | Method | Test score | Details |
|---|---|---|---|
| deepseek-ai/deepseek-llm-7b-chat | PPO | 70.5 [1] | |
| deepseek-ai/deepseek-llm-7b-chat | GRPO | 71.4 [1] | |