modules.flax_modelling_utils
BaseJAXAttentionModule
Bases: Module
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 418-540)
add_start_docstrings(*docstr)
The add_start_docstrings function returns a decorator that prepends docstrings to a function. It takes an arbitrary number of strings; the returned decorator takes one argument, fn, which is assumed to be a function. The docstring of fn is set to the concatenation of all the strings passed to add_start_docstrings, followed by the original docstring of fn if it exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
docstr | | Strings to prepend to the docstring of the decorated function | () |
Returns:
Type | Description |
---|---|
A decorator that adds the docstrings to the function |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 363-380)
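
The behaviour described for add_start_docstrings can be sketched in a few lines of plain Python. This is a minimal illustration written from the description above, not the library's source; the decorated function `forward` and the strings passed in are hypothetical.

```python
def add_start_docstrings(*docstr):
    """Return a decorator that prepends the given strings to a docstring."""

    def docstring_decorator(fn):
        # Join the supplied strings, then append the original docstring
        # only if the function actually has one.
        original = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = "".join(docstr) + original
        return fn

    return docstring_decorator


# Hypothetical usage: share one header across several functions.
@add_start_docstrings("Shared header. ", "Applies to every model. ")
def forward(x):
    """Run the forward pass."""
    return x


print(forward.__doc__)
# Shared header. Applies to every model. Run the forward pass.
```

The decorator returns the same function object, so only the docstring changes and call behaviour is untouched.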
apply_rotary_pos_emb(tensor, sin_, cos_)
The apply_rotary_pos_emb function applies a rotary positional embedding to the input tensor. The tensor is expected in b,h,s,d (PyTorch-style) layout.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor | | The input tensor, in b,h,s,d layout | required |
sin_ | | The sine component of the rotary embedding | required |
cos_ | | The cosine component of the rotary embedding | required |
Returns:
Type | Description |
---|---|
A tensor with the same shape as the input tensor |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 301-313)
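
A minimal sketch of how a rotary embedding is commonly applied to a b,h,s,d tensor, assuming the conventional formula `tensor * cos + rotate_half(tensor) * sin`. NumPy stands in for jax.numpy so the example runs without JAX. The frequency table, the `_sketch` names, and the shapes are illustrative assumptions rather than the library's code.

```python
import numpy as np


def rotate_half_sketch(x):
    # Split the last axis into halves (x1, x2) and return (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((-x2, x1), axis=-1)


def apply_rotary_pos_emb_sketch(tensor, sin_, cos_):
    # Conventional rotary formula (an assumption, not copied from the library).
    return tensor * cos_ + rotate_half_sketch(tensor) * sin_


batch, heads, seq, head_dim = 1, 2, 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, heads, seq, head_dim))

# Hypothetical frequency table: one angle per (position, frequency) pair,
# duplicated so it covers both halves of the last axis.
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
angles = np.outer(np.arange(seq), inv_freq)         # (seq, head_dim // 2)
angles = np.concatenate((angles, angles), axis=-1)  # (seq, head_dim)
sin_, cos_ = np.sin(angles), np.cos(angles)         # broadcast over b and h

q_rot = apply_rotary_pos_emb_sketch(q, sin_, cos_)
print(q_rot.shape)  # (1, 2, 4, 8)
```

Because each pair of coordinates is rotated by an angle, the output keeps the shape and the per-vector norm of the input, and position 0 (angle zero) is left unchanged.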
canonicalize_dtype(*args, dtype=None, inexact=True)
Canonicalize an optional dtype to the definitive dtype.
If the dtype is None, this function infers the dtype from the input arguments using jnp.result_type. If it is not None, it is returned unmodified, or an exception is raised if the dtype is invalid.
Args:
args: JAX array compatible values. None values are ignored.
dtype: Optional dtype override. If specified, the arguments are cast to the specified dtype instead and dtype inference is disabled.
inexact: When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers, such as taking a mean.
Returns:
The dtype that args should be cast to.
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 38-67)
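
The inference rule above can be sketched as follows. NumPy stands in for jax.numpy here so the example runs without JAX, and NumPy's promotion rules differ from JAX's for some integer widths. The step that promotes an inferred integer dtype to a floating-point one is an assumption about how the inexact requirement is satisfied, and the name `canonicalize_dtype_sketch` is hypothetical.

```python
import numpy as np


def canonicalize_dtype_sketch(*args, dtype=None, inexact=True):
    if dtype is None:
        # Infer the dtype from the non-None arguments.
        arrays = [np.asarray(a) for a in args if a is not None]
        dtype = np.result_type(*arrays)
        if inexact and not np.issubdtype(dtype, np.inexact):
            # Assumed behaviour: promote integer results to a float dtype.
            dtype = np.promote_types(np.float32, dtype)
    if inexact and not np.issubdtype(dtype, np.inexact):
        # An explicit override that is not inexact is rejected.
        raise ValueError(f"Dtype must be inexact: {dtype}")
    return np.dtype(dtype)


print(canonicalize_dtype_sketch(np.ones(3, dtype=np.float32), None))  # float32
print(canonicalize_dtype_sketch(np.arange(3, dtype=np.int8)))         # float32
print(canonicalize_dtype_sketch(dtype=np.int32, inexact=False))       # int32
```

Passing `dtype=np.int32` with the default `inexact=True` raises a ValueError in this sketch, which matches the "exception is raised if the dtype is invalid" case.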
create_mesh(axis_dims=(1, -1, 1, 1), axis_names=('dp', 'fsdp', 'tp', 'sp'), backend='')
The create_mesh function creates a mesh object that can be used to shard arrays.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
axis_dims | Sequence[int] | Specify the dimensions of the mesh | (1, -1, 1, 1) |
axis_names | Sequence[str] | Name the axes of the mesh | ('dp', 'fsdp', 'tp', 'sp') |
backend | | Specify the backend to use | '' |
Returns:
Type | Description |
---|---|
A mesh object |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 342-360)
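
The default axis_dims contains a -1. The sketch below assumes that -1 is resolved the way an array reshape resolves it, i.e. that axis absorbs all devices not claimed by the other axes. That is an assumption about create_mesh, and `resolve_axis_dims` is a hypothetical helper that only computes the resulting mesh shape, without touching any devices.

```python
import math


def resolve_axis_dims(axis_dims, n_devices):
    """Replace a single -1 in axis_dims so the product equals n_devices."""
    axis_dims = tuple(axis_dims)
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes.
    known = math.prod(d for d in axis_dims if d != -1)
    if -1 in axis_dims:
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices do not divide into {axis_dims}")
        return tuple(n_devices // known if d == -1 else d for d in axis_dims)
    if known != n_devices:
        raise ValueError(f"{axis_dims} does not cover {n_devices} devices")
    return axis_dims


axis_names = ("dp", "fsdp", "tp", "sp")
print(dict(zip(axis_names, resolve_axis_dims((1, -1, 1, 1), 8))))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(dict(zip(axis_names, resolve_axis_dims((2, -1, 2, 1), 8))))
# {'dp': 2, 'fsdp': 2, 'tp': 2, 'sp': 1}
```

Under this assumption the default arguments place every available device on the fsdp axis.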
get_dot_general_by_bits(bits=None, mode=EasyMethod.TRAIN)
The get_dot_general_by_bits function is a helper function that returns a q_flax.QDotGeneral object with the specified number of bits for forward and backward passes. If no bits are specified, the function returns None.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bits | Optional[int] | Specify the number of bits for quantization | None |
mode | Literal['train', 'serve', 'convert'] | Specify how the model will be used when initializing the QDot method (e.g. TRAIN, SERVE, ...) | TRAIN |
Returns:
Type | Description |
---|---|
dict | A dict that contains dot_general_cls |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 383-415)
get_gradient_checkpoint_policy(name)
The get_gradient_checkpoint_policy function is a helper function that returns the gradient checkpoint policy specified by the name parameter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | | Name of the checkpoint policy to select | required |
Returns:
Type | Description |
---|---|
A checkpoint policy function that is used in JAX |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 114-135)
get_names_from_partition_spec(partition_specs)
The get_names_from_partition_spec function takes a partition_specs argument, which is either a dictionary or a list. If it is a dictionary, the function first converts it to a list of its values. Then, for each item in partition_specs: if the item is None, it is skipped; if the item is a str, it is added to the set of names; otherwise, get_names_from_partition_spec is called recursively on the item.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
partition_specs | | A dictionary or list of partition specs, possibly nested | required |
Returns:
Type | Description |
---|---|
A list of the names found in the partition specs |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 70-94)
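
The traversal described for get_names_from_partition_spec can be sketched in plain Python. Plain tuples stand in for partition specs here, the function name `names_from_specs` and the parameter names in `specs` are hypothetical, and merging the recursive results into the same set is an assumption.

```python
def names_from_specs(partition_specs):
    names = set()
    if isinstance(partition_specs, dict):
        # Only the values of a dictionary carry axis names.
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            # Unsharded dimension: nothing to record.
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            # Nested spec (tuple, list, or dict): recurse and merge.
            names.update(names_from_specs(item))
    return list(names)


specs = {
    "embed": ("tp", ("fsdp", "sp")),
    "norm": (None,),
    "proj": (("dp", "fsdp"), "tp"),
}
print(sorted(names_from_specs(specs)))  # ['dp', 'fsdp', 'sp', 'tp']
```

Collecting into a set removes duplicates, so the order of the returned list is not meaningful; the example sorts it for display.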
get_ranks_and_size(mesh)
The get_ranks_and_size function is used to determine the number of MPI processes (mp_node_size) and the number of devices per process (dp_node_size). The mesh.shape[mp] determines how many MPI processes are needed, and then we divide that by the local device count to get mp_node_size = max(1, mp / jax.local). This means that if there are more than enough devices for all MPI ranks on a node, each rank will only use one device; otherwise it will use
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | | The mesh whose shape is inspected | required |
Returns:
Type | Description |
---|---|
A dictionary with the following keys: |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 316-339)
names_in_mesh(*names)
The names_in_mesh function checks whether the axis names passed to it are valid, that is, whether every one of them is an axis of the physical mesh.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
names | | The axis names to check | () |
Returns:
Type | Description |
---|---|
A boolean indicating whether all the given names are in the physical mesh |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 97-108)
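
The Returns entry for names_in_mesh describes a boolean result. A membership check of that kind can be sketched as a subset test. In this sketch the mesh axis names are passed in explicitly through a hypothetical first parameter, whereas the library function takes only the names and presumably reads the mesh from the active JAX context.

```python
def names_in_mesh_sketch(mesh_axis_names, *names):
    # True only if every requested name is an axis of the mesh.
    return set(names) <= set(mesh_axis_names)


mesh_axes = ("dp", "fsdp", "tp", "sp")
print(names_in_mesh_sketch(mesh_axes, "dp", "tp"))  # True
print(names_in_mesh_sketch(mesh_axes, "dp", "pp"))  # False
```

With no names given the check is vacuously true, since the empty set is a subset of any set.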
repeat_kv_bnsh(x, n_rep)
The repeat_kv_bnsh function repeats the key and value heads in a multi-head attention module. It takes an array of shape (batch_size, n_kv_heads, sequence_length, head_dim) and returns an array of shape (batch_size, n_kv_heads * n_rep, sequence_length, head_dim). This is needed when there are fewer key/value heads than query heads, as in grouped-query attention, so that every query head has a matching key/value head.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Array | The input array, of shape (batch_size, n_kv_heads, sequence_length, head_dim) | required |
n_rep | int | Number of times to repeat each key and value head | required |
Returns:
Type | Description |
---|---|
Array | A new array with the same shape as x, except for the second dimension which is n_kv_heads * n_rep |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 138-157)
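
A minimal sketch of the shape transformation for the bnsh layout, using NumPy in place of jax.numpy so it runs without JAX. The expand, repeat, and reshape steps are one common way to implement this and are an assumption, not the library's source; `repeat_kv_bnsh_sketch` is a hypothetical name.

```python
import numpy as np


def repeat_kv_bnsh_sketch(x, n_rep):
    bs, n_kv_heads, s, hd = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis next to the head axis, copy along it,
    # then fold it back into the head axis.
    x = x[:, :, np.newaxis, :, :]                     # (bs, n_kv, 1, s, hd)
    x = np.repeat(x, n_rep, axis=2)                   # (bs, n_kv, n_rep, s, hd)
    return x.reshape(bs, n_kv_heads * n_rep, s, hd)   # (bs, n_kv * n_rep, s, hd)


x = np.arange(2 * 2 * 3 * 4, dtype=np.float32).reshape(2, 2, 3, 4)
y = repeat_kv_bnsh_sketch(x, n_rep=3)
print(y.shape)  # (2, 6, 3, 4)
```

In the output, heads 0 to 2 are copies of key/value head 0 and heads 3 to 5 are copies of key/value head 1, so consecutive query heads share one key/value head.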
repeat_kv_bsnh(x, n_rep)
The repeat_kv_bsnh function repeats the key and value heads n_rep times for an input in (batch_size, sequence_length, n_kv_heads, head_dim) layout.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Array | Specify the input array | required |
n_rep | int | Repeat the key-value attention heads n_rep times | required |
Returns:
Type | Description |
---|---|
Array | A new array with the same batch size, sequence length, and head dimension as the input array, and n_kv_heads * n_rep heads |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 160-178)
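
For the bsnh layout the head axis is the third one, so the same repetition can be sketched as a single repeat along axis 2. As before, NumPy stands in for jax.numpy, and `repeat_kv_bsnh_sketch` is a hypothetical name for an illustrative implementation rather than the library's code.

```python
import numpy as np


def repeat_kv_bsnh_sketch(x, n_rep):
    # x has shape (batch, seq, n_kv_heads, head_dim);
    # each head is copied n_rep times, with the copies kept adjacent.
    return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)


x = np.random.default_rng(0).standard_normal((2, 5, 2, 4))
y = repeat_kv_bsnh_sketch(x, n_rep=4)
print(y.shape)  # (2, 5, 8, 4)
```

Only the head axis grows; batch size, sequence length, and head dimension are unchanged.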
rotate_half(x)
The rotate_half function splits the last axis of the input array into two halves, x1 and x2, and returns the concatenation of -x2 and x1 along that axis. Treating each pair (x1, x2) as a point in the plane, this is a 90-degree counterclockwise rotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | | Specify the input array | required |
Returns:
Type | Description |
---|---|
A new array with the same shape as the input |
Source code in src/python/easydel/modules/flax_modelling_utils.py (lines 285-298)
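
In rotary-embedding code, rotate_half conventionally splits the last axis into halves (x1, x2) and returns (-x2, x1). The sketch below assumes that convention and uses NumPy in place of jax.numpy; it is a worked example with a hypothetical `_sketch` name, not the library's source.

```python
import numpy as np


def rotate_half_sketch(x):
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # (x1, x2) -> (-x2, x1): a quarter turn of each (x1_i, x2_i) pair.
    return np.concatenate((-x2, x1), axis=-1)


x = np.array([1.0, 2.0, 3.0, 4.0])
once = rotate_half_sketch(x)
twice = rotate_half_sketch(once)
print(once.tolist())   # [-3.0, -4.0, 1.0, 2.0]
print(twice.tolist())  # [-1.0, -2.0, -3.0, -4.0]
```

Applying the function twice negates the input, as expected from two successive quarter turns.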