Skip to content

transform.falcon

falcon_from_pretrained(model_id, device)

Returns: the weights (params) for the EasyDeL model, and the model config.

Source code in src/python/easydel/transform/falcon.py (lines 22–37):
def falcon_from_pretrained(model_id, device):
    """
    Load a Falcon checkpoint and convert its weights for use with EasyDeL.

    :param model_id: identifier passed to ``from_pretrained`` for both the
        config and the PyTorch model.
    :param device: forwarded unchanged to ``falcon_convert_pt_to_flax_7b``.
    :return: a tuple ``(params, config)``. ``params`` is the converted weights
        returned by ``falcon_convert_pt_to_flax_7b``. ``config`` is the loaded
        ``FalconConfig`` after ``add_jax_args()`` has been called on it.
    """
    # Requested By vwxyzjn at https://github.com/erfanzar/EasyDeL/issues/15#issue-1881044170
    falcon_config = FalconConfig.from_pretrained(model_id)
    torch_model = FalconForCausalLM.from_pretrained(model_id)
    converted_params = falcon_convert_pt_to_flax_7b(
        state_dict=torch_model.state_dict(),
        config=falcon_config,
        device=device,
    )
    # The PyTorch model is no longer needed once its weights are converted.
    # Drop the reference and force a collection, presumably so its memory can
    # be reclaimed before returning.
    del torch_model
    gc.collect()
    falcon_config.add_jax_args()
    return converted_params, falcon_config