Skip to content

transform.mistral

mistral_easydel_to_hf(path, config)

Takes the path to a checkpoint saved by EasyDel and returns the model in PyTorch (Hugging Face Transformers).

Source code in src/python/easydel/transform/mistral.py
250
251
252
253
254
255
256
257
258
259
260
def mistral_easydel_to_hf(path, config: MistralConfig):
    """
    Load an EasyDel checkpoint and return it as a PyTorch `MistralForCausalLM`.

    The checkpoint at `path` is converted with
    `load_and_convert_checkpoint_to_torch`. Every occurrence of `.kernel`
    and `.embedding` in a parameter name is replaced with `.weight`, and the
    renamed parameters are loaded into a model freshly built from `config`.

    :param path: location of the saved EasyDel checkpoint.
    :param config: configuration used to construct the `MistralForCausalLM`.
    :return: the `MistralForCausalLM` instance with the converted weights loaded.
    """

    def _to_torch_name(name):
        # `.kernel` is rewritten first, then `.embedding`; both become `.weight`.
        return name.replace('.kernel', '.weight').replace('.embedding', '.weight')

    state_dict = {
        _to_torch_name(name): tensor
        for name, tensor in load_and_convert_checkpoint_to_torch(path).items()
    }

    hf_model = MistralForCausalLM(config=config)
    hf_model.load_state_dict(state_dict)
    return hf_model

mistral_from_pretrained(model_id, device)

Returns the weights (params) for the EasyDel model, together with its config.

Source code in src/python/easydel/transform/mistral.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def mistral_from_pretrained(model_id, device):
    """
    Load a pretrained Mistral model and convert its weights for EasyDel.

    The config and the PyTorch model are both obtained through
    `from_pretrained(model_id)`. The model's state dict is converted with
    `mistral_convert_hf_to_flax`, after which `config.add_jax_args()` is
    called on the config.

    :param model_id: identifier passed to `from_pretrained` for both the
        config and the model.
    :param device: forwarded unchanged to `mistral_convert_hf_to_flax`.
    :return: tuple `(easydel_weights, config)`.
    """
    config = MistralConfig.from_pretrained(model_id)
    model = MistralForCausalLM.from_pretrained(model_id)

    # The state dict is passed as a temporary so no extra reference to the
    # PyTorch tensors outlives the conversion call.
    easydel_weights = mistral_convert_hf_to_flax(
        state_dict=model.state_dict(), config=config, device=device
    )
    config.add_jax_args()

    # Drop the PyTorch model and force a collection pass before returning.
    del model
    gc.collect()

    return easydel_weights, config