
EasyDeLState

EasyDeLState is a handy feature in EasyDeL with a lot of options: it can store the model parameters, the optimizer state, the model config, the model type, and the optimizer and scheduler configs.

Let's look at some examples of using EasyDeLState.

Fine-tuning

Fine-tuning from a previous state or a new one

from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState
)
from transformers import AutoTokenizer
from jax import numpy as jnp, lax
import jax

huggingface_model_repo_id = "REPO_ID"
checkpoint_name = "CKPT_NAME"

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path=huggingface_model_repo_id,
    filename=checkpoint_name,
    optimizer="adamw",  # optimizer to attach to the state
    scheduler="none",  # no learning-rate scheduler
    tx_init=None,
    device=jax.devices('cpu')[0],  # Offload Device
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=lax.Precision("fastest"),
    sharding_axis_dims=(1, -1, 1, 1),  # mesh dimensions for the axes named below
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),  # data, fully-sharded data, tensor and sequence parallelism
    query_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    key_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    value_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    bias_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), None, None, None),
    attention_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    shard_attention_computation=False,
    input_shape=(1, 1),
    backend=None,
    init_optimizer_state=False,  # don't allocate the optimizer state yet
    free_optimizer_state=True,  # free any loaded optimizer state to save memory
    verbose=True,
    state_shard_fns=None,
)

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)

max_length = config.max_position_embeddings

# kwargs used to (re-)initialize the model class, e.g. when handing the state to a trainer
configs_to_initialize_model_class = {
    'config': config,
    'dtype': jnp.bfloat16,
    'param_dtype': jnp.bfloat16,
    'input_shape': (8, 8)
}
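
This dictionary is typically handed over to an EasyDeL trainer so it can re-instantiate the model class with the right config, dtypes, and input shape. Below is a minimal, hedged sketch of that hand-off; TrainArguments and CausalLanguageModelTrainer are EasyDeL classes, but the exact argument names and the placeholder values (MODEL_CLASS, train_dataset, the run name) are assumptions to verify against the current docs:

from easydel import TrainArguments, CausalLanguageModelTrainer

train_arguments = TrainArguments(
    model_name="my-finetune",  # placeholder run name
    model_class=MODEL_CLASS,  # placeholder: the EasyDeL Flax model class to re-create
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    num_train_epochs=1,
)

trainer = CausalLanguageModelTrainer(
    train_arguments,
    train_dataset,  # placeholder: your tokenized dataset
)

output = trainer.train(state=state)  # resume fine-tuning from the loaded state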

EasyDeLState also provides .load_state() and .save_state(), along with other useful helpers such as .free_opt_state(), which frees the optimizer state, and .shard_params(), which shards the parameters; see the docs to find out more about these options.
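
For example, here is a minimal sketch of these helpers (the checkpoint filename is a placeholder, and the exact signatures and return values should be checked against the docs):

# save the current state to a checkpoint file (placeholder filename)
state.save_state("easydel_checkpoint.easy")

# restore it later
state = EasyDeLState.load_state("easydel_checkpoint.easy")

# free the optimizer state to reclaim memory, e.g. before inference
state = state.free_opt_state()

# shard the parameters across the device mesh
state = state.shard_params()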

Converting to Hugging Face and PyTorch

Let's see how you can convert an EasyDeL Mistral model to a Hugging Face PyTorch Mistral model from a trained state.


from transformers import MistralForCausalLM
from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState,
    easystate_to_huggingface_model
)
import jax

huggingface_model_repo_id = "REPO_ID"

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDeLState.load_state(
            "PATH_TO_CKPT",
            input_shape=(8, 2048)
        ),  # You can Pass EasyDeLState here
        base_huggingface_module=MistralForCausalLM,
        config=config,
    )

model = model.half()  # it's a huggingface model now
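
Since the result is a regular Hugging Face model, you can save or publish it with the standard transformers APIs (the output directory below is just a placeholder):

from transformers import AutoTokenizer

output_dir = "mistral-converted"  # placeholder output directory

# write the weights and config in Hugging Face format
model.save_pretrained(output_dir)

# store the matching tokenizer next to it so the folder loads directly
tokenizer = AutoTokenizer.from_pretrained(huggingface_model_repo_id)
tokenizer.save_pretrained(output_dir)

# optionally push to the Hugging Face Hub
# model.push_to_hub("USERNAME/REPO_NAME")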

Other Use Cases

EasyDeLState is general-purpose: you can use it everywhere in EasyDeL, for example for a stand-alone model, for serving, for fine-tuning, and for many other features. It's up to you to test how creative you can be 😇.