Skip to content

transform.llama

llama_easydel_to_hf(path, config)

Takes the path to a checkpoint saved by EasyDeL and returns the model in PyTorch (Hugging Face Transformers).

Source code in src/python/easydel/transform/llama.py
150
151
152
153
154
155
156
157
158
159
160
def llama_easydel_to_hf(path, config: LlamaConfig):
    """
    Load an EasyDeL checkpoint and return it as a Hugging Face PyTorch model.

    The checkpoint at ``path`` is converted to torch parameters with
    ``load_and_convert_checkpoint_to_torch``. Every parameter name then has
    ``.kernel`` and ``.embedding`` replaced by ``.weight`` before the result
    is loaded into a ``LlamaForCausalLM`` constructed from ``config``.

    Args:
        path: location of the saved EasyDeL checkpoint.
        config: configuration used to construct the ``LlamaForCausalLM``.

    Returns:
        The ``LlamaForCausalLM`` instance with the renamed parameters loaded.
    """
    converted = load_and_convert_checkpoint_to_torch(path)
    # Rewrite parameter names so they use the ``.weight`` suffix.
    state_dict = {
        name.replace('.kernel', '.weight').replace('.embedding', '.weight'): value
        for name, value in converted.items()
    }
    hf_model = LlamaForCausalLM(config=config)
    hf_model.load_state_dict(state_dict)
    return hf_model

llama_from_pretrained(model_id, device)

Returns: the weights (params) for the EasyDeL model, and the config.

Source code in src/python/easydel/transform/llama.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def llama_from_pretrained(model_id, device):
    """
    Load a pretrained Llama model and convert its weights for EasyDeL.

    Args:
        model_id: identifier passed to ``LlamaConfig.from_pretrained`` and
            ``LlamaForCausalLM.from_pretrained``.
        device: forwarded unchanged to ``llama_convert_hf_to_flax``.

    Returns:
        A ``(params, config)`` tuple: the output of
        ``llama_convert_hf_to_flax`` and the loaded config after
        ``add_jax_args()`` has been called on it.
    """
    config = LlamaConfig.from_pretrained(model_id)
    hf_model = LlamaForCausalLM.from_pretrained(model_id)
    # The state dict is passed inline so no extra reference to it outlives
    # the ``del`` below.
    easydel_params = llama_convert_hf_to_flax(
        state_dict=hf_model.state_dict(),
        config=config,
        device=device,
    )
    config.add_jax_args()

    # Drop the torch model and run a collection pass before returning.
    del hf_model
    gc.collect()
    return easydel_params, config