utils.tensor_utils
np2jax(array)
Convert a NumPy array to a JAX array.
Source code in src/python/easydel/utils/tensor_utils.py
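The body of `np2jax` is not reproduced on this page. The sketch below assumes it wraps `jax.numpy.asarray`. Treat it as an illustration of the conversion, not the library's code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def np2jax(array: np.ndarray) -> jax.Array:
    """Convert a NumPy array to a JAX array (sketch; assumes jnp.asarray)."""
    # jnp.asarray places the data on JAX's default device.
    # Unless jax_enable_x64 is enabled, float64 input is stored as float32.
    return jnp.asarray(array)


x_np = np.arange(6, dtype=np.float32).reshape(2, 3)
x_jax = np2jax(x_np)
print(x_jax.shape, x_jax.dtype)
```

The input here uses `float32` deliberately. With JAX's default configuration, a `float64` NumPy array would be downcast to `float32` during conversion.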
pt2jax(array)
Convert a PyTorch tensor to a JAX array.
Source code in src/python/easydel/utils/tensor_utils.py
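The following is a minimal sketch of how `pt2jax` could work. It assumes the conversion goes through NumPy: the tensor is first converted to a host NumPy array and then handed to JAX. The actual implementation in `tensor_utils.py` may take a different route.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def pt2jax(array: torch.Tensor) -> jax.Array:
    """Convert a PyTorch tensor to a JAX array via NumPy (sketch)."""
    # detach() drops autograd tracking and cpu() moves GPU tensors to host
    # memory; both are required before .numpy() can be called.
    host = array.detach().cpu().numpy()
    # jnp.asarray then places the host data on JAX's default device.
    return jnp.asarray(host)


t = torch.ones(2, 3, requires_grad=True)
x = pt2jax(t)
print(x.shape, x.dtype)
```

Going through host memory keeps the sketch simple and device-agnostic. The cost is an extra copy when the tensor starts on an accelerator.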
pt2np(array)
Convert a PyTorch tensor to a NumPy array.
Source code in src/python/easydel/utils/tensor_utils.py
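A minimal sketch of `pt2np` follows. It assumes the usual `detach().cpu().numpy()` chain, since the source body is not shown on this page.

```python
import numpy as np
import torch


def pt2np(array: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array (sketch)."""
    # Tensor.numpy() refuses tensors that require grad or live on a GPU,
    # so detach from the autograd graph and move to host memory first.
    return array.detach().cpu().numpy()


t = torch.arange(4, dtype=torch.float32, requires_grad=True)
arr = pt2np(t)
print(arr, arr.dtype)
```

For a CPU tensor, the returned array shares memory with the original tensor. Call `.copy()` on it if an independent buffer is needed. NumPy also has no `bfloat16` dtype, so a `torch.bfloat16` tensor must be cast first (for example with `.float()`) before `.numpy()` will accept it.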