Skip to content

transform.mpt

mpt_from_pretrained(model_id, device, **kwargs)

Returns: the weights (params) for the EasyDeL model, and its config.

Source code in src/python/easydel/transform/mpt.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def mpt_from_pretrained(model_id, device, **kwargs):
    """
    Load an MPT checkpoint and convert its weights for an EasyDeL model.

    Args:
        model_id: Model identifier or path, passed to both
            ``MptConfig.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Forwarded unchanged to ``mpt_convert_pt_to_flax_7b``.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoModelForCausalLM.from_pretrained``. ``trust_remote_code``
            defaults to ``True`` but may be overridden by the caller.

    Returns:
        tuple: ``(easydel_weights, config)``. These are the weights produced
        by ``mpt_convert_pt_to_flax_7b`` and the ``MptConfig`` after
        ``add_jax_args()`` has been called on it.
    """
    config = MptConfig.from_pretrained(model_id)

    # Keep the previous default of trusting remote code, but let callers
    # override it. Passing it both explicitly and through **kwargs would
    # raise a duplicate-keyword TypeError. ``kwargs`` is a fresh dict per
    # call, so mutating it here does not affect the caller.
    kwargs.setdefault("trust_remote_code", True)
    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

    # Prefer ``num_hidden_layers``; fall back to ``n_layers`` for configs
    # that only define the latter.
    n_layers = (
        config.num_hidden_layers
        if hasattr(config, "num_hidden_layers")
        else config.n_layers
    )

    easydel_weights = mpt_convert_pt_to_flax_7b(
        state_dict=model.state_dict(),
        n_layers=n_layers,
        device=device,
    )
    config.add_jax_args()

    # Only the converted weights are needed from here on; release the
    # PyTorch model before returning.
    del model
    gc.collect()
    return easydel_weights, config