
transform.easydel_transform

float_tensor_to_dtype(tensor, dtype)

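The float_tensor_to_dtype function casts a tensor to `dtype`, but only if the tensor already has a floating-point dtype (`bfloat16`, `float16`, `float32` or `float64`). Integer, boolean and non-array values are returned unchanged. If `dtype` is `None` or an empty string the tensor is returned as-is, and a string `dtype` is first resolved with `get_dtype`.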

Parameters:

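| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | | The tensor to convert; it is cast only if its dtype is floating-point. | required |
| `dtype` | | Target dtype, either a dtype object or a string name resolved by `get_dtype`. `None` or `""` skips the cast. | required |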

Returns:

| Type | Description |
|---|---|
| | The tensor cast to `dtype` if it was floating-point, otherwise the input unchanged. |

Source code in src/python/easydel/transform/easydel_transform.py
```python
def float_tensor_to_dtype(tensor, dtype):
    """
    The float_tensor_to_dtype function casts a floating-point tensor to the specified dtype.

    :param tensor: The tensor to convert; only floating-point tensors are cast
    :param dtype: The target dtype (a dtype object or a string name); None or "" skips the cast
    :return: The tensor cast to dtype if it was floating-point, otherwise the input unchanged

    """
    if dtype is None or dtype == "":
        return tensor
    if isinstance(dtype, str):
        dtype = get_dtype(dtype)
    float_dtypes = (jax.numpy.bfloat16, jax.numpy.float16, jax.numpy.float32, jax.numpy.float64)
    # Non-float tensors (and objects without a dtype) fall through unchanged
    if getattr(tensor, "dtype", None) in float_dtypes:
        tensor = tensor.astype(dtype)
    return tensor
```
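
To make the float-only rule concrete, here is a minimal sketch that uses NumPy as a stand-in for `jax.numpy`. It leaves out `bfloat16` (plain NumPy has no such dtype) and skips the `get_dtype` string lookup; `cast_floats_only` is a hypothetical name, not part of EasyDeL.

```python
import numpy as np

# NumPy stand-ins for the jax.numpy float dtypes checked by float_tensor_to_dtype
# (bfloat16 is omitted because plain NumPy does not provide it).
FLOAT_DTYPES = (np.float16, np.float32, np.float64)


def cast_floats_only(tensor, dtype):
    """Hypothetical sketch: cast `tensor` to `dtype` only if it is a float array."""
    if dtype is None or dtype == "":
        return tensor  # no target dtype: return the input untouched
    if getattr(tensor, "dtype", None) in FLOAT_DTYPES:
        tensor = tensor.astype(dtype)
    return tensor


weights = np.ones((2, 2), dtype=np.float32)
step = np.array(7, dtype=np.int32)
print(cast_floats_only(weights, np.float16).dtype)  # float16
print(cast_floats_only(step, np.float16).dtype)     # int32: not a float, left unchanged
```

This is why save_ckpt can pass a whole train state through the function: integer counters such as a step number keep their dtype while float parameters are down-cast.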

huggingface_to_easydel(state_dict, *, device, embedding_layer_names=None, layer_norm_names=None, shard_fns=None, convert_to_8bit=False, params_pattern_selection=None, dtype=jax.numpy.float16, rnn_based_or_rwkv=False, verbose=True, remove_state_dict=False, **kwargs)

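The huggingface_to_easydel function converts a Hugging Face (PyTorch) model's state_dict into the nested parameter dictionary used by an EasyDeL (Flax) model. It is designed to be used together with the load_huggingface function, which loads a Hugging Face model from disk. The keyword-only `device` argument is required: conversion runs with it set as the JAX default device. `embedding_layer_names` and `layer_norm_names` are optional, but should be supplied so that those layers are renamed correctly.

Each key is rewritten by the first rule that applies, in this order: a key containing one of `embedding_layer_names` has its trailing `.weight` replaced by `.embedding` (the key is assumed to end in `.weight`); if `rnn_based_or_rwkv` is set, a tensor whose key contains `time_` is flattened to 1-D and keeps its name; a key containing one of `layer_norm_names` has `.weight` renamed to `.scale`; any other key containing `weight` has `.weight` renamed to `.kernel`, and 2-D weights are transposed. Each tensor is then cast to `dtype`, passed through its `shard_fns` entry if one exists, quantized to 8-bit if `convert_to_8bit` is set and its path matches `params_pattern_selection`, and stored under its dot-split key.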

Parameters:

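| Name | Type | Description | Default |
|---|---|---|---|
| `state_dict` | | The Hugging Face model's state dict whose weights are converted. | required |
| `embedding_layer_names` | `Optional[List[str]]` | Substrings identifying the embedding layers; matching keys are renamed from `.weight` to `.embedding`. | `None` |
| `device` | | Device set as the JAX default device during conversion. | required |
| `layer_norm_names` | `Optional[List[str]]` | Substrings identifying normalization layers; matching keys are renamed from `.weight` to `.scale`. | `None` |
| `shard_fns` | `Optional[Mapping[tuple, Callable]]` | Sharding functions keyed by parameter key tuple; applied to each converted array whose key is present. | `None` |
| `convert_to_8bit` | `bool` | Whether to quantize the selected parameters to 8-bit format. | `False` |
| `params_pattern_selection` | `Optional[re.Pattern]` | Pattern searched against the `/`-joined converted key to choose which parameters are quantized; required when `convert_to_8bit` is `True`. | `None` |
| `dtype` | `jax.numpy.dtype` | Data type every converted array is cast to. | `float16` |
| `rnn_based_or_rwkv` | `bool` | If `True`, any tensor whose key contains `time_` (for example `time_mix_`) is reshaped to 1-D for EasyDeL's RWKV layout. | `False` |
| `verbose` | `bool` | Whether to show a progress bar during conversion. | `True` |
| `remove_state_dict` | `bool` | Whether to delete each entry from `state_dict` once it has been converted. | `False` |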
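
When `convert_to_8bit` is set, `params_pattern_selection.search` runs against the converted key joined with `/`, so the pattern sees renamed paths such as `.../q_proj/kernel` rather than the original dotted Hugging Face key. The sketch below reproduces only that selection step with the standard `re` module; `select_for_8bit`, the example keys and the pattern are hypothetical, and the quantization itself is not shown.

```python
import re
from typing import Iterable, List, Pattern, Tuple


def select_for_8bit(key_tuples: Iterable[Tuple[str, ...]], pattern: Pattern[str]) -> List[str]:
    """Hypothetical sketch: return the '/'-joined paths that `pattern.search` matches."""
    selected = []
    for key_tuple in key_tuples:
        path = "/".join(key_tuple)  # same join used before quantizing a parameter
        if pattern.search(path):  # search, not fullmatch: a partial match is enough
            selected.append(path)
    return selected


# Hypothetical keys after renaming (.weight -> .embedding / .kernel / .scale)
keys = [
    ("model", "embed_tokens", "embedding"),
    ("model", "layers", "0", "self_attn", "q_proj", "kernel"),
    ("model", "layers", "0", "mlp", "up_proj", "kernel"),
    ("model", "layers", "0", "input_layernorm", "scale"),
]
# Hypothetical pattern: quantize linear-layer kernels only
pattern = re.compile(r"_proj/kernel$")
print(select_for_8bit(keys, pattern))
```

A pattern written against the original names (for example `q_proj\.weight`) would match nothing, because the search happens after renaming and after the dots have become slashes.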

Returns:

| Type | Description |
|---|---|
| | A nested dictionary of the converted weights and biases, rebuilt from the dot-split keys with `traverse_util.unflatten_dict`, in a format Flax can use. |

Source code in src/python/easydel/transform/easydel_transform.py
```python
def huggingface_to_easydel(
        state_dict,
        *,
        device,
        embedding_layer_names: Optional[List[str]] = None,
        layer_norm_names: Optional[List[str]] = None,
        shard_fns: Optional[Mapping[tuple, Callable]] = None,
        convert_to_8bit: bool = False,
        params_pattern_selection: Optional[re.Pattern] = None,
        dtype: jax.numpy.dtype = jax.numpy.float16,
        rnn_based_or_rwkv: bool = False,
        verbose: bool = True,
        remove_state_dict: bool = False,
        **kwargs
):
    """
    The huggingface_to_easydel function takes a huggingface model's state_dict and converts it to an easydel
    model's flax_dict. The function is designed to be used in conjunction with the load_huggingface function, which
    loads a huggingface model from disk. The device on which the conversion will take place must be specified;
    the embedding layer names are optional.

    :param state_dict: The state_dict of a huggingface model whose weights are converted
    :param embedding_layer_names: List[str]: Identify the embedding layer(s) in the huggingface model
    :param device: The device used as the JAX default device during conversion
    :param layer_norm_names: Identify normalization layers, whose weight is renamed to scale
    :param shard_fns: Optional[Mapping[tuple, Callable]]: Sharding functions used to shard the model
    :param convert_to_8bit: bool: whether to convert the selected parameters into 8bit format
    :param params_pattern_selection: Optional[re.Pattern]: pattern used to find the parameters of the model which
    will be converted to 8bit format.
    :param dtype: jax.numpy.dtype: Specify the data type of the tensors
    :param rnn_based_or_rwkv: bool: if True, any tensor whose key contains time_ (e.g. time_mix_) is reshaped
    to 1-D for the easydel use case
    :param verbose: bool: whether to show a progress bar during conversion
    :param remove_state_dict: bool: whether to remove entries from the state dict during the transforming process
    :return: A nested dictionary of the weights and biases in a format that can be used by flax (built with
    unflatten_dict)

    """
    embedding_layer_names = set(embedding_layer_names or [])
    layer_norm_names = set(layer_norm_names or [])
    _l = len(".weight")
    _b = len(".bias")

    if convert_to_8bit:
        assert params_pattern_selection is not None, (
            "in case of converting parameters to 8bit you should pass "
            "`params_pattern_selection` too, to tell the quantizer which parameters should be quantized."
        )

    with jax.default_device(device):
        flax_dict = {}
        pbar = tqdm(total=len(state_dict), disable=not verbose)

        pbar.set_description("Converting Model")

        for key, tensor in list(state_dict.items()):
            # Determine if renaming is necessary
            new_key = key
            if any(layer_name in key for layer_name in embedding_layer_names):
                new_key = key[:-_l] + ".embedding"
            elif rnn_based_or_rwkv and ("time_mix_" in key or "time_" in key):
                tensor = tensor.reshape(-1)
            elif any(layer_norm in key for layer_norm in layer_norm_names):
                new_key = key.replace(".weight", ".scale")
            elif "weight" in key:
                if len(tensor.shape) == 2:
                    tensor = tensor.transpose(0, 1)
                new_key = key.replace(".weight", ".kernel")

            key_tuple = tuple(new_key.split("."))
            # Detach the torch tensor, move it to CPU, convert it to a JAX array and cast it to `dtype`
            array = jax.lax.convert_element_type(jnp.asarray(tensor.cpu().detach().numpy()), dtype)
            if remove_state_dict:
                del tensor
                del state_dict[key]
            # Apply sharding functions if provided
            if shard_fns and key_tuple in shard_fns:
                array = shard_fns[key_tuple](array)
            if convert_to_8bit:
                if params_pattern_selection.search("/".join(key_tuple)):
                    array = fjformer.linen.linen.LinearBitKernel(
                        *fjformer.linen.linen.quantize(array, int_dtype=jnp.int8)  # type: ignore
                    )
            flax_dict[key_tuple] = array

            # Advance the progress bar by one converted tensor
            pbar.update(1)
        pbar.close()
        gc.collect()
        return traverse_util.unflatten_dict(flax_dict)
```
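
The renaming rules can be traced without PyTorch or JAX. The sketch below is a hypothetical helper, `convert_key`, that follows the same branch order as huggingface_to_easydel but works on a key string and the tensor's number of dimensions, returning the new key tuple plus a label for what would happen to the tensor. It is an illustration, not part of EasyDeL.

```python
from typing import Iterable, Tuple


def convert_key(
    key: str,
    ndim: int,
    embedding_layer_names: Iterable[str] = (),
    layer_norm_names: Iterable[str] = (),
    rnn_based_or_rwkv: bool = False,
) -> Tuple[Tuple[str, ...], str]:
    """Hypothetical sketch of the key-renaming branches.

    Returns (new key tuple, action), where action is "flatten" (reshape(-1)),
    "transpose" (2-D weight) or "none".
    """
    new_key, action = key, "none"
    if any(name in key for name in embedding_layer_names):
        new_key = key[: -len(".weight")] + ".embedding"  # assumes the key ends in ".weight"
    elif rnn_based_or_rwkv and "time_" in key:  # "time_mix_" also contains "time_"
        action = "flatten"
    elif any(name in key for name in layer_norm_names):
        new_key = key.replace(".weight", ".scale")
    elif "weight" in key:
        if ndim == 2:
            action = "transpose"
        new_key = key.replace(".weight", ".kernel")
    return tuple(new_key.split(".")), action


examples = [
    ("model.embed_tokens.weight", 2),
    ("model.layers.0.input_layernorm.weight", 1),
    ("model.layers.0.self_attn.q_proj.weight", 2),
    ("model.layers.0.self_attn.q_proj.bias", 1),
]
for key, ndim in examples:
    print(key, "->", convert_key(key, ndim, ["embed_tokens"], ["input_layernorm"]))
```

Because every check is a substring test, a short entry in `embedding_layer_names` or `layer_norm_names` can match more keys than intended, so it is safer to pass names that are specific to the target layers.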

match_keywords(string, ts, ns)

The match_keywords function checks a string against two lists of keywords: `ts`, the "must-have" keywords, all of which must appear in `string`, and `ns`, the "not-allowed" keywords, none of which may appear. It returns True only if both conditions hold; matching is plain substring containment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `string` | | The text being searched. | required |
| `ts` | | Required keywords; every one must be present in `string`. | required |
| `ns` | | Forbidden keywords; none may be present in `string`. | required |

Returns:

| Type | Description |
|---|---|
| `bool` | `True` if all the keywords in `ts` are present and none of the keywords in `ns` are present, otherwise `False`. |

Source code in src/python/easydel/transform/easydel_transform.py
```python
def match_keywords(string, ts, ns):
    """
    The match_keywords function takes a string, and two lists of strings.
    The first list is the "must-have" keywords, and the second list is the "not-allowed" keywords.
    It returns True if all the must-have keywords are in string, but none of the not-allowed keywords are.

    :param string: The text that is being searched
    :param ts: The required keywords, all of which must appear in string
    :param ns: The forbidden keywords, none of which may appear in string
    :return: True if all the keywords in ts are present and none of the keywords in ns are present

    """
    # Every required keyword must be a substring of `string`
    for t in ts:
        if t not in string:
            return False
    # No forbidden keyword may be a substring of `string`
    for n in ns:
        if n in string:
            return False
    return True
```
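
A typical use is filtering parameter paths. The snippet re-states match_keywords with `all()`/`any()`, which is equivalent to the loops above, so that it runs on its own; the paths are hypothetical.

```python
def match_keywords(string, ts, ns):
    # Equivalent re-statement: all required keywords present, no forbidden keyword present
    return all(t in string for t in ts) and not any(n in string for n in ns)


# Hypothetical parameter paths
paths = [
    "model/layers/0/self_attn/q_proj/kernel",
    "model/layers/0/self_attn/q_proj/bias",
    "model/layers/0/mlp/down_proj/kernel",
]
# Keep attention kernels, drop biases
attention_kernels = [p for p in paths if match_keywords(p, ts=["self_attn", "kernel"], ns=["bias"])]
print(attention_kernels)
```

Empty `ts` and `ns` lists make the function return `True` for any string, and a short keyword such as `proj` matches every path that contains it.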

read_ckpt(path, shard_fns=None, add_extra_past_fix=None)

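The read_ckpt function reads a msgpack checkpoint file, such as one written by save_ckpt, and returns its tensors as a flat dictionary keyed by tuples of key components. Records are streamed one at a time, each value is deserialized with `from_bytes`, and the optional `shard_fns` are applied as the tensors are loaded.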

Parameters:

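| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` or `os.PathLike` | Path to the checkpoint file. | required |
| `shard_fns` | | Mapping from key tuple to a function applied to the tensor stored under that key (for example, to shard it). When given, it must contain an entry for every key in the file. | `None` |
| `add_extra_past_fix` | `list` | Key prefix: a list of key components prepended to every key read from the file, before `shard_fns` is looked up. | `None` |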

Returns:

| Type | Description |
|---|---|
| | A flat dictionary mapping key tuples to tensors. |

Source code in src/python/easydel/transform/easydel_transform.py
```python
def read_ckpt(path: [str, os.PathLike], shard_fns=None, add_extra_past_fix: list = None):
    """
    The read_ckpt function reads a checkpoint file and returns the tensors in it.

    :param path: [str, os.PathLike]: Specify the path to the checkpoint file
    :param shard_fns: Mapping from key tuple to a function applied to each tensor (e.g. to shard it)
    :param add_extra_past_fix: list: A prefix prepended to every key
    :return: A dictionary of tensors

    """
    tensors = {}
    with open(path, "rb") as stream:
        # Read the file in 80 MiB chunks and decode one (key, value) record at a time
        unpacker = msgpack.Unpacker(stream, read_size=83886080, max_buffer_size=0)
        for key, value in unpacker:
            if add_extra_past_fix is not None:
                key = add_extra_past_fix + key
            # msgpack decodes the stored key tuple as a list
            key = tuple(key)
            tensor = from_bytes(None, value)
            if shard_fns is not None:
                tensor = shard_fns[key](tensor)
            tensors[key] = tensor
    return tensors
```
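
The key handling in read_ckpt can be checked on already-decoded records. In this sketch a plain list stands in for the msgpack `Unpacker` stream and values are kept as-is instead of going through `from_bytes`; `collect_records` is a hypothetical name. It shows that the prefix is prepended before `shard_fns` is consulted, so `shard_fns` must be keyed by the prefixed tuples.

```python
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple


def collect_records(
    records: Iterable[Tuple[List[Any], Any]],
    shard_fns: Optional[Mapping[tuple, Callable]] = None,
    add_extra_past_fix: Optional[List[Any]] = None,
) -> Dict[tuple, Any]:
    """Hypothetical sketch of read_ckpt's per-record key handling."""
    tensors = {}
    for key, value in records:
        if add_extra_past_fix is not None:
            key = add_extra_past_fix + key  # list + list: the prefix comes first
        key = tuple(key)  # decoded keys are lists; dict keys must be hashable tuples
        if shard_fns is not None:
            value = shard_fns[key](value)  # raises KeyError if a key has no function
        tensors[key] = value
    return tensors


# Hypothetical decoded records (keys arrive as lists)
records = [(["params", "dense", "kernel"], 1.0), (["params", "dense", "bias"], 0.0)]
print(collect_records(records, add_extra_past_fix=["model"]))
```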

save_ckpt(train_state, path, gather_fns=None, float_dtype=None)

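The save_ckpt function writes a training state to disk as a stream of msgpack records. The state is converted with `to_state_dict` and flattened, and each leaf is written as a `(key, to_bytes(value))` pair after the optional `gather_fns` and `float_dtype` cast have been applied, so the file can be read back with read_ckpt.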

Parameters:

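| Name | Type | Description | Default |
|---|---|---|---|
| `train_state` | | The training state to save; it is converted with `to_state_dict` and flattened before writing. | required |
| `path` | | Location of the checkpoint file to write. | required |
| `gather_fns` | | A pytree of functions with the same structure as `train_state`; each function is applied to the matching value before serialization (for example, to gather a sharded array). | `None` |
| `float_dtype` | | If set, floating-point values are cast to this dtype with `float_tensor_to_dtype` before writing; `None` leaves them unchanged. | `None` |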

Returns:

| Type | Description |
|---|---|
| | `None`; the checkpoint is written to `path`. |

Source code in src/python/easydel/transform/easydel_transform.py
```python
def save_ckpt(train_state, path, gather_fns=None, float_dtype=None):
    """
    The save_ckpt function saves the state of a training run to disk.

    :param train_state: The current state of the training process
    :param path: Specify the location of the checkpoint file
    :param gather_fns: A pytree of functions, matching train_state, applied to each value before it is serialized
    :param float_dtype: Convert floating-point tensors to a specific dtype before saving
    :return: Nothing

    """

    # Turn the train state into nested dicts, then flatten them to {key tuple: leaf}
    train_state = to_state_dict(train_state)
    packer = msgpack.Packer()
    flatten_train_state = flatten_dict(train_state)
    if gather_fns is not None:
        gather_fns = flatten_dict(to_state_dict(gather_fns))

    with open(path, "wb") as stream:
        for key, value in flatten_train_state.items():
            if gather_fns is not None:
                value = gather_fns[key](value)
            value = float_tensor_to_dtype(value, float_dtype)
            # Write one (key, serialized value) msgpack record per leaf
            stream.write(packer.pack((key, to_bytes(value))))
```
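
The save side can be sketched the same way. Assuming plain nested dicts in place of a Flax train state, the hypothetical helpers below flatten the state into `{key tuple: leaf}` (in the spirit of `flax.traverse_util.flatten_dict`) and apply the matching `gather_fns` entry to each leaf; the `float_dtype` cast and the msgpack/`to_bytes` serialization are left out. Here `int` and `tuple` merely stand in for real gather functions.

```python
from typing import Any, Dict, List, Optional, Tuple


def flatten(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Hypothetical sketch: flatten nested dicts into {path tuple: leaf}."""
    flat = {}
    for name, child in tree.items():
        path = prefix + (name,)
        if isinstance(child, dict):
            flat.update(flatten(child, path))  # recurse into sub-trees
        else:
            flat[path] = child
    return flat


def prepare_leaves(state: Dict[str, Any], gather_fns: Optional[Dict[str, Any]] = None) -> List[Tuple[tuple, Any]]:
    """Return the (key, value) pairs that would be serialized, before any dtype cast."""
    flat_state = flatten(state)
    flat_fns = flatten(gather_fns) if gather_fns is not None else None
    records = []
    for key, value in flat_state.items():
        if flat_fns is not None:
            value = flat_fns[key](value)  # gather_fns must mirror the state's structure
        records.append((key, value))
    return records


state = {"step": 3, "params": {"dense": {"kernel": [1.0, 2.0], "bias": [0.5]}}}
gather_fns = {"step": int, "params": {"dense": {"kernel": tuple, "bias": tuple}}}
print(prepare_leaves(state, gather_fns))
```

Because `gather_fns` is flattened the same way and indexed by the same key, it needs exactly the same structure as the train state; a missing entry raises `KeyError`.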