transform.easydel_transform
float_tensor_to_dtype(tensor, dtype)
The float_tensor_to_dtype function is used to convert a tensor's dtype to the specified dtype.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor | | The tensor to convert | required |
dtype | | The dtype to convert the tensor to | required |
Returns:
Type | Description |
---|---|
A tensor with the specified dtype |
Source code in src/python/easydel/transform/easydel_transform.py
huggingface_to_easydel(state_dict, *, device, embedding_layer_names=None, layer_norm_names=None, shard_fns=None, convert_to_8bit=False, params_pattern_selection=None, dtype=jax.numpy.float16, rnn_based_or_rwkv=False, verbose=True, remove_state_dict=False, **kwargs)
The huggingface_to_easydel function takes a Hugging Face model's state_dict and converts it to an EasyDel model's flax_dict. It is designed to be used together with the load_huggingface function, which loads a Hugging Face model from disk. The embedding layer name must be specified, as well as the device on which the conversion will take place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | | The state dict holding the weights of a Hugging Face model | required |
embedding_layer_names | Optional[List[str]] | List[str]: Names that identify the embedding layer in the Hugging Face model | None |
device | | The device the model will be loaded on | required |
layer_norm_names | Optional[List[str]] | Names of the layers whose weight or kernel is replaced with scale | None |
shard_fns | Optional[Mapping[tuple, Callable]] | Optional[Mapping[tuple, Callable]]: Sharding functions used to shard the model | None |
convert_to_8bit | bool | bool: Whether to convert the weights into 8-bit format | False |
dtype | dtype | jax.numpy.dtype: The data type of the tensors | float16 |
rnn_based_or_rwkv | bool | bool: When True, any value in the tree whose name starts with time_mix_ is automatically reshaped for the EasyDel use case | False |
verbose | bool | bool: Whether to log the sharding or conversion process | True |
remove_state_dict | bool | bool: Whether to remove the state dict during the transformation process | False |
Returns:
Type | Description |
---|---|
A dictionary of the weights and biases in a format that can be used by flax (it's an UnFlattenDict) |
Source code in src/python/easydel/transform/easydel_transform.py
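The renaming described for layer_norm_names and the nested return value can be illustrated with a minimal, pure-Python sketch. It is not the library's implementation: the helper name rename_and_unflatten is hypothetical, plain lists stand in for tensors, and mapping the remaining weight entries to kernel follows the usual Flax naming convention rather than anything stated in this page. Tensor conversion, sharding, and device placement are left out.

```python
def rename_and_unflatten(state_dict, layer_norm_names=None):
    """Hypothetical helper: rename trailing 'weight' keys and nest the result.

    Assumption: keys are dot-separated strings such as 'model.norm.weight'.
    """
    layer_norm_names = layer_norm_names or []
    nested = {}
    for key, value in state_dict.items():
        parts = key.split(".")
        if parts[-1] == "weight":
            if any(name in key for name in layer_norm_names):
                # Layer-norm parameters are stored under 'scale'.
                parts[-1] = "scale"
            else:
                # Assumed Flax convention for dense layers.
                parts[-1] = "kernel"
        # Walk (and create) the nested dictionaries for the key path.
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


# Plain lists stand in for tensors in this illustration.
example_state_dict = {
    "model.norm.weight": [1.0, 1.0],
    "model.proj.weight": [[0.1, 0.2]],
    "model.proj.bias": [0.0],
}
converted = rename_and_unflatten(example_state_dict, layer_norm_names=["norm"])
print(converted)
```

Here the layer-norm weight ends up under scale, the projection weight under kernel, and the bias keeps its name, all inside one nested dictionary keyed by the path segments.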
match_keywords(string, ts, ns)
The match_keywords function takes a string and two lists of strings. The first list holds the "must-have" keywords, and the second list holds the "not-allowed" keywords. It returns True if all the must-have keywords are in the string and none of the not-allowed keywords are in it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
string | | The text that is being searched | required |
ts | | List of required ("must-have") keywords | required |
ns | | List of negative ("not-allowed") keywords | required |
Returns:
Type | Description |
---|---|
bool | True if all the keywords in ts are present and none of the keywords in ns are present |
Source code in src/python/easydel/transform/easydel_transform.py
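The behaviour described above can be sketched in a few lines. This is a minimal illustration and not the source listing itself: the name match_keywords_sketch is hypothetical, and it assumes that "in string" means plain substring matching.

```python
def match_keywords_sketch(string, ts, ns):
    """Return True if every keyword in ts occurs in string and no keyword in ns does.

    Assumption: keywords are matched as plain substrings.
    """
    has_all_required = all(t in string for t in ts)
    has_no_forbidden = not any(n in string for n in ns)
    return has_all_required and has_no_forbidden


# A parameter name that contains both required keywords and no forbidden one.
weight_match = match_keywords_sketch(
    "model.layers.0.attn.q_proj.weight", ["attn", "weight"], ["bias"]
)
# A parameter name that contains a forbidden keyword.
bias_match = match_keywords_sketch(
    "model.layers.0.attn.q_proj.bias", ["attn"], ["bias"]
)
print(weight_match, bias_match)
```

With empty ts and ns lists the sketch returns True for any string, since there is nothing to require and nothing to forbid.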
read_ckpt(path, shard_fns=None, add_extra_past_fix=None)
The read_ckpt function reads a checkpoint file and returns the tensors in it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | [str, PathLike] | [str, os.PathLike]: Specify the path to the checkpoint file | required |
shard_fns | | Shard the tensors | None |
add_extra_past_fix | list | list: Add an extra past to the key | None |
Returns:
Type | Description |
---|---|
A dictionary of tensors |
Source code in src/python/easydel/transform/easydel_transform.py
save_ckpt(train_state, path, gather_fns=None, float_dtype=None)
The save_ckpt function saves the state of a training run to disk.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_state | | The current state of the training process | required |
path | | The location of the checkpoint file | required |
gather_fns | | A function that will be used to convert the tensor to bytes | None |
float_dtype | | The dtype to convert the tensor to | None |
Returns:
Type | Description |
---|---|
Nothing |
Source code in src/python/easydel/transform/easydel_transform.py