Add models with the FSDP backend

Model

In principle, our FSDP backend can support any HF model, and we can synchronize the actor model's weights with vLLM using hf_weight_loader.py under third_party/vllm. However, hf_weight_loader gathers the full state_dict of the model during synchronization, which may cause OOM. We suggest using dtensor_weight_loader instead, which gathers the model parameters layer by layer to reduce peak memory usage (a sketch of this per-parameter gathering follows the list below). We already support the dtensor weight loader for the models below in dtensor_weight_loaders.py under third_party/vllm:

  • GPT2LMHeadModel

  • LlamaForCausalLM

  • LLaMAForCausalLM

  • MistralForCausalLM

  • InternLMForCausalLM

  • AquilaModel

  • AquilaForCausalLM

  • Phi3ForCausalLM

  • GemmaForCausalLM

  • Gemma2ForCausalLM

  • GPTBigCodeForCausalLM

  • Starcoder2ForCausalLM

  • Qwen2ForCausalLM

  • DeepseekV2ForCausalLM
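
For intuition, here is a minimal sketch of the layer-by-layer idea, not verl's actual helper: the function name sync_weights_layer_by_layer is illustrative, and it assumes the actor and vLLM parameter names coincide (the real loaders also remap stacked parameters such as qkv_proj). Each sharded DTensor is gathered on its own, copied into the vLLM module, and freed before the next one, so the peak overhead per rank is one full parameter instead of the whole state_dict.

from typing import Dict

import torch
from torch import nn
from torch.distributed._tensor import DTensor, Replicate

def sync_weights_layer_by_layer(actor_weights: Dict[str, torch.Tensor], vllm_model: nn.Module):
    params_dict = dict(vllm_model.named_parameters())
    for name, weight in actor_weights.items():
        if isinstance(weight, DTensor):
            # Gather only this one parameter to a full (replicated) tensor.
            weight = weight.redistribute(
                placements=[Replicate()] * weight.device_mesh.ndim
            ).to_local()
        # Illustrative: assumes identical naming between the actor and vLLM model.
        params_dict[name].data.copy_(weight)
        del weight  # free the gathered copy before the next parameter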

To implement the dtensor_weight_loader for a model that is already supported in vLLM, follow the guide for the gemma model below:

  1. Copy the load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) method from the vLLM model class into dtensor_weight_loaders.py.

  2. Change the arguments to (actor_weights: Dict, vllm_model: nn.Module).

  3. Replace self with vllm_model.

  4. Add local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) before each param = params_dict[name], and change the subsequent weight loading to use local_loaded_weight.

  5. Register the implemented dtensor weight loader in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ (a registration sketch follows the diff below).

The diff below shows steps 1 through 4 applied to the gemma load_weights copied from vLLM:

- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
-   params_dict = dict(self.named_parameters())
+   params_dict = dict(vllm_model.named_parameters())
    loaded_params = set()
-   for name, loaded_weight in weights:
+   for name, loaded_weight in actor_weights.items():
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
+           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
-           weight_loader(param, loaded_weight, shard_id)
+           weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
+           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
-           weight_loader(param, loaded_weight)
+           weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
        loaded_params.add(name)
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        raise RuntimeError(
            "Some weights are not initialized from checkpoints: "
            f"{unloaded_params}")