partition_utils.mesh_utils
create_mesh(axis_dims=(1, -1, 1, 1), axis_names=('dp', 'fsdp', 'tp', 'sp'), backend='')
The create_mesh function creates a mesh object that can be used to shard arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `axis_dims` | `Sequence[int]` | Dimensions of the mesh | `(1, -1, 1, 1)` |
| `axis_names` | `Sequence[str]` | Names of the mesh axes | `('dp', 'fsdp', 'tp', 'sp')` |
| `backend` | | Backend to use | `''` |

Returns:
| Type | Description |
|---|---|
| | A mesh object |
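The default `axis_dims` contains a `-1` entry. The following is a minimal sketch of how such an entry can be resolved, assuming it is inferred from the number of devices the way a reshape infers a missing dimension (the text above does not state this). The helper `resolve_axis_dims` and the device count of 8 are hypothetical and not part of the library.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], device_count: int) -> Tuple[int, ...]:
    """Hypothetical helper: fill a single -1 entry so the product equals device_count."""
    dims = list(axis_dims)
    unknown = [i for i, d in enumerate(dims) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly given dimensions (1 if there are none).
    known = prod(d for d in dims if d != -1)
    if unknown:
        if known == 0 or device_count % known != 0:
            raise ValueError(f"cannot split {device_count} devices over {tuple(axis_dims)}")
        dims[unknown[0]] = device_count // known
    elif known != device_count:
        raise ValueError(f"{tuple(axis_dims)} does not cover {device_count} devices")
    return tuple(dims)


# Assumed setup: 8 devices and the default axis layout from the signature above.
resolved = resolve_axis_dims((1, -1, 1, 1), device_count=8)
axis_sizes = dict(zip(("dp", "fsdp", "tp", "sp"), resolved))
print(axis_sizes)
```

Under this assumption, the default layout places every device on the `fsdp` axis.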
Source code in src/fjformer/partition_utils/mesh_utils.py
flatten_tree(xs, is_leaf=None, sep=None)
The flatten_tree function takes a nested structure of arrays and returns a
dictionary mapping from string keys to the corresponding array values. The
string keys are derived from the tree path to each value, with sep used as
the separator between levels in the tree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `xs` | | Tree (nested structure of arrays) to flatten | required |
| `is_leaf` | | Determines whether a node is treated as a leaf | `None` |
| `sep` | | Separator placed between the keys of each path | `None` |

Returns:
| Type | Description |
|---|---|
| | A dict mapping flattened tree paths to values |
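A minimal pure-Python sketch of the flattening described above, written for nested dicts, lists, and tuples rather than arbitrary pytrees. The helper `flatten_nested` and the `"/"` separator are illustration choices, not the library's implementation.

```python
from typing import Any, Dict


def flatten_nested(tree: Any, sep: str = "/", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: map 'a/b/c'-style paths to the leaves of a nested container."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        # Anything that is not a container is treated as a leaf.
        return {prefix: tree}
    flat: Dict[str, Any] = {}
    for key, value in items:
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        flat.update(flatten_nested(value, sep, path))
    return flat


params = {
    "encoder": {"dense": {"kernel": 1.0, "bias": 0.0}},
    "head": {"kernel": 2.0},
}
flat_params = flatten_nested(params)
print(flat_params)
```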
Source code in src/fjformer/partition_utils/mesh_utils.py
get_jax_mesh(axis_dims, names)
The get_jax_mesh function takes a string of the form <axis_dims>, where axis_dims is a comma-separated list of dimensions. Each dimension is written either as <name>:<dim> or as <dim>. If there are no names, the default names 'x', 'y', and 'z' are used. If there are fewer than three dimensions, the remaining dimensions are set to 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `axis_dims` | | Dimensions of the mesh | required |
| `names` | | Names of the mesh dimensions | required |

Returns:
| Type | Description |
|---|---|
| | A mesh object |
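A minimal sketch of parsing the two string forms mentioned above (`<name>:<dim>` and `<dim>`). It only parses the string and does not build a mesh. The helper `parse_axis_dims` is hypothetical, and pairing unnamed dimensions positionally with `names` is an assumption of this sketch rather than documented behaviour.

```python
from typing import List, Sequence, Tuple


def parse_axis_dims(axis_dims: str, names: Sequence[str]) -> Tuple[List[int], List[str]]:
    """Hypothetical helper: split 'dp:1,fsdp:-1' or '1,-1' into sizes and axis names."""
    if ":" in axis_dims:
        dims, dim_names = [], []
        for axis in axis_dims.split(","):
            name, dim = axis.split(":")
            if name not in names:
                raise ValueError(f"unknown axis name {name!r}")
            dims.append(int(dim))
            dim_names.append(name)
    else:
        # Assumption: unnamed sizes are matched to `names` by position.
        dims = [int(x) for x in axis_dims.split(",")]
        dim_names = list(names)
    if len(dims) != len(names):
        raise ValueError("number of dimensions must match number of names")
    return dims, dim_names


named = parse_axis_dims("dp:1,fsdp:-1,tp:1", ("dp", "fsdp", "tp"))
unnamed = parse_axis_dims("1,-1,1", ("dp", "fsdp", "tp"))
print(named, unnamed)
```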
Source code in src/fjformer/partition_utils/mesh_utils.py
get_metrics(metrics, unreplicate=False, stack=False)
The get_metrics function is a helper function that takes the metrics dictionary returned by the training loop and converts it to a format that can be used for plotting. Two optional flags, unreplicate and stack, control the conversion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `metrics` | | Store the metrics that we want to track | required |
| `unreplicate` | | Convert the metrics from a replicated | `False` |
| `stack` | | Stack the metrics in a list | `False` |

Returns:
| Type | Description |
|---|---|
| | A dictionary of metrics |
Source code in src/fjformer/partition_utils/mesh_utils.py
get_names_from_partition_spec(partition_specs)
The get_names_from_partition_spec function takes a partition_specs argument, which is either a dictionary or a list. A dictionary is first converted to a list of its values. Each item in the list is then handled as follows. If the item is None, it is skipped. If the item is a str, it is added to the set of names. Otherwise, get_names_from_partition_spec is called recursively on the item.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `partition_specs` | | Specify the partitioning of the data | required |

Returns:
| Type | Description |
|---|---|
| | A list of names |
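The recursion described above can be sketched in pure Python. Plain tuples stand in for partition specs here, and the helper `collect_names` is hypothetical. Sorting the result is a choice made for this sketch so the output is deterministic.

```python
from typing import Any, List


def collect_names(partition_specs: Any) -> List[str]:
    """Hypothetical helper: gather every axis name found in a nested spec structure."""
    names = set()
    if isinstance(partition_specs, dict):
        # A dictionary is reduced to its values before iterating.
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue  # None marks an unsharded dimension; nothing to record.
        elif isinstance(item, str):
            names.add(item)
        else:
            # Nested container: recurse and merge the names it contains.
            names.update(collect_names(item))
    return sorted(names)


specs = {
    "kernel": ("dp", None, ("fsdp", "tp")),  # tuples stand in for partition specs
    "bias": (None,),
}
axis_names = collect_names(specs)
print(axis_names)
```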
Source code in src/fjformer/partition_utils/mesh_utils.py
get_weight_decay_mask(exclusions)
Return a weight decay mask function that computes the pytree masks according to the given exclusion rules.
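A minimal sketch of the idea, under two assumptions that the description above does not confirm: each exclusion rule is a regular expression matched against a parameter's path, and the mask is `True` where weight decay should apply. The helper `make_decay_mask` is hypothetical and works on a flat dict of paths instead of a pytree.

```python
import re
from typing import Any, Callable, Dict, Sequence


def make_decay_mask(exclusions: Sequence[str]) -> Callable[[Dict[str, Any]], Dict[str, bool]]:
    """Hypothetical helper: build a function that maps parameter paths to decay flags."""

    def mask_fn(flat_params: Dict[str, Any]) -> Dict[str, bool]:
        # A parameter is decayed only if no exclusion pattern matches its path.
        return {
            name: not any(re.search(rule, name) for rule in exclusions)
            for name in flat_params
        }

    return mask_fn


flat_params = {"dense/kernel": 1.0, "dense/bias": 0.0, "layernorm/scale": 1.0}
decay_mask = make_decay_mask(["bias", "layernorm"])(flat_params)
print(decay_mask)
```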
Source code in src/fjformer/partition_utils/mesh_utils.py
make_shard_and_gather_fns(partition_specs, dtype_specs=None)
The make_shard_and_gather_fns function takes in a partition_specs and dtype_specs, and returns two functions: shard_fns and gather_fns. The shard function is used to shard the input tensor into the specified partitions. The gather function is used to gather all the shards back together into one tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `partition_specs` | | Specify the sharding of the input tensor | required |
| `dtype_specs` | | Specify the dtype of the tensor | `None` |

Returns:
| Type | Description |
|---|---|
| | A tuple of functions |
Source code in src/fjformer/partition_utils/mesh_utils.py
match_partition_rules(rules, params)
Returns a pytree of PartitionSpec according to rules. Supports handling Flax TrainState and Optax optimizer state.
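A minimal sketch of rule matching, assuming `rules` is an ordered sequence of (regular expression, spec) pairs and that the first matching rule wins. Neither assumption is confirmed by the description above. Tuples stand in for PartitionSpec objects, a flat dict stands in for the parameter pytree, and the helper `match_rules` is hypothetical.

```python
import re
from typing import Any, Dict, Sequence, Tuple

Spec = Tuple[Any, ...]  # stand-in for a PartitionSpec


def match_rules(rules: Sequence[Tuple[str, Spec]], flat_params: Dict[str, Any]) -> Dict[str, Spec]:
    """Hypothetical helper: assign each parameter path the spec of its first matching rule."""

    def spec_for(name: str) -> Spec:
        for pattern, spec in rules:
            if re.search(pattern, name) is not None:
                return spec
        raise ValueError(f"no partition rule matches {name!r}")

    return {name: spec_for(name) for name in flat_params}


rules = [
    ("embedding", ("tp", "fsdp")),
    ("kernel", ("fsdp", "tp")),
    (".*", (None,)),  # catch-all: leave everything else unsharded
]
flat_params = {"embedding/kernel": 0, "mlp/kernel": 0, "mlp/bias": 0}
matched = match_rules(rules, flat_params)
print(matched)
```

Ordering the rules from most to least specific and ending with a catch-all keeps every parameter covered in this sketch.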
Source code in src/fjformer/partition_utils/mesh_utils.py
named_tree_map(f, tree, *rest, is_leaf=None, sep=None)
An extended version of jax.tree_util.tree_map, where the mapped function f takes both the name (path) and the tree leaf as input.
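A minimal pure-Python sketch of a name-aware map over nested dicts. The helper `named_map` is hypothetical, handles only dicts, and ignores the `*rest` and `is_leaf` arguments of the real signature. The `"/"` separator is an illustration choice.

```python
from typing import Any, Callable


def named_map(f: Callable[[str, Any], Any], tree: Any, sep: str = "/", prefix: str = "") -> Any:
    """Hypothetical helper: call f(path, leaf) on every leaf of a nested dict."""
    if isinstance(tree, dict):
        return {
            key: named_map(f, value, sep, f"{prefix}{sep}{key}" if prefix else str(key))
            for key, value in tree.items()
        }
    # Non-dict values are leaves: pass both the accumulated path and the value.
    return f(prefix, tree)


tree = {"a": {"b": 1}, "c": 2}
labelled = named_map(lambda name, leaf: f"{name}={leaf}", tree)
print(labelled)
```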
Source code in src/fjformer/partition_utils/mesh_utils.py
names_in_current_mesh(*names)
The names_in_current_mesh function checks whether a set of names is in the current mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*names` | | Names to look up in the current mesh | `()` |

Returns:
| Type | Description |
|---|---|
| | A boolean indicating whether the names are in the current mesh |
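A minimal sketch of the check, assuming that every given name must be an axis of the mesh for the result to be true. The helper `names_in_mesh` is hypothetical and takes the mesh's axis names explicitly, whereas the real function reads them from the current mesh.

```python
from typing import Sequence


def names_in_mesh(mesh_axis_names: Sequence[str], *names: str) -> bool:
    """Hypothetical helper: True if all given names are axes of the mesh."""
    return set(names) <= set(mesh_axis_names)


mesh_axes = ("dp", "fsdp", "tp", "sp")
all_present = names_in_mesh(mesh_axes, "dp", "tp")
one_missing = names_in_mesh(mesh_axes, "dp", "pp")
print(all_present, one_missing)
```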
Source code in src/fjformer/partition_utils/mesh_utils.py
tree_apply(fns, tree)
The tree_apply function is a generalization of the map function. It takes two arguments: a pytree of functions and a pytree of values. The tree_apply function applies each function in the first argument to its corresponding value in the second argument, and returns a new pytree with these results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fns` | | Pytree of functions to apply | required |
| `tree` | | Pytree of values the functions are applied to | required |

Returns:
| Type | Description |
|---|---|
| | A pytree of the same structure as the input |
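A minimal pure-Python sketch of applying a tree of functions to a matching tree of values. The helper `apply_tree` is hypothetical and handles only nested dicts.

```python
from typing import Any


def apply_tree(fns: Any, tree: Any) -> Any:
    """Hypothetical helper: apply each function in `fns` to the value at the same position."""
    if isinstance(fns, dict):
        return {key: apply_tree(fns[key], tree[key]) for key in fns}
    # `fns` is a single callable here, so apply it to the corresponding value.
    return fns(tree)


fns = {"a": abs, "b": {"c": str}}
values = {"a": -3, "b": {"c": 7}}
applied = apply_tree(fns, values)
print(applied)
```

With JAX available, the same effect on general pytrees is commonly obtained with `jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)`.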
Source code in src/fjformer/partition_utils/mesh_utils.py
tree_path_to_string(path, sep=None)
The tree_path_to_string function takes a tree path and returns a string representation of it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | | Tree path to convert | required |
| `sep` | | Separator used to join the keys | `None` |

Returns:
| Type | Description |
|---|---|
| | A tuple of strings |
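A minimal sketch of converting a path to strings, assuming the keys are returned as a tuple when `sep` is `None` and joined into one string otherwise. This is an assumption that reconciles the "tuple of strings" return value with the `sep` parameter. Plain strings and integers stand in for tree path entries, and the helper `path_to_string` is hypothetical.

```python
from typing import Optional, Sequence, Tuple, Union


def path_to_string(path: Sequence[object], sep: Optional[str] = None) -> Union[Tuple[str, ...], str]:
    """Hypothetical helper: stringify each path entry, optionally joining with `sep`."""
    keys = tuple(str(key) for key in path)
    if sep is None:
        return keys
    return sep.join(keys)


path = ("params", "layers", 0, "kernel")
as_tuple = path_to_string(path)
as_string = path_to_string(path, sep="/")
print(as_tuple, as_string)
```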
Source code in src/fjformer/partition_utils/mesh_utils.py
with_sharding_constraint(x, partition_specs)
A smarter version of with_sharding_constraint that only applies the constraint if the current mesh contains the axes in the partition specs.
Source code in src/fjformer/partition_utils/mesh_utils.py
wrap_function_with_rng(rng)
Intended for use as a decorator. It automatically keeps track of an RNG for the wrapped function.
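A minimal sketch of the decorator pattern, assuming the wrapper splits its stored RNG on every call and passes a fresh one to the wrapped function as the first argument. To stay self-contained, integer seeds and the `split` function below are pure-Python stand-ins for JAX PRNG keys and key splitting. Both they and `wrap_with_rng` are hypothetical names.

```python
import random
from functools import wraps
from typing import Callable, Tuple


def split(seed: int) -> Tuple[int, int]:
    """Stand-in for a key split: derive two new integer seeds from one."""
    gen = random.Random(seed)
    return gen.getrandbits(32), gen.getrandbits(32)


def wrap_with_rng(rng: int) -> Callable:
    """Hypothetical decorator factory that threads an RNG through repeated calls."""

    def decorator(function: Callable) -> Callable:
        @wraps(function)
        def wrapped(*args, **kwargs):
            nonlocal rng
            # Keep one half for the next call, hand the other half to this call.
            rng, call_rng = split(rng)
            return function(call_rng, *args, **kwargs)

        return wrapped

    return decorator


@wrap_with_rng(0)
def sample(rng: int, scale: float) -> float:
    # The caller never passes `rng`; the decorator supplies it.
    return scale * random.Random(rng).random()


first = sample(2.0)
second = sample(2.0)
print(first, second)
```

Because the stored RNG advances on each call, repeated calls with the same arguments draw different values, while starting again from the same seed reproduces the sequence.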
Source code in src/fjformer/partition_utils/mesh_utils.py