Skip to content

utils.tensor_utils

np2jax(array)

Convert a NumPy array to a JAX array.

Source code in src/python/easydel/utils/tensor_utils.py (lines 14–18)
def np2jax(array: np.ndarray) -> chex.Array:
    """Convert a NumPy array to a JAX array.

    Args:
        array: The NumPy array to convert. Any array-like accepted by
            ``jnp.asarray`` also works.

    Returns:
        A JAX array with the same values, produced by ``jnp.asarray``.

    Note:
        ``jnp.asarray`` follows JAX's dtype rules, so 64-bit inputs may be
        canonicalized to 32-bit unless ``jax_enable_x64`` is enabled.
    """
    return jnp.asarray(array)

pt2jax(array)

Convert a PyTorch tensor to a JAX array.

Source code in src/python/easydel/utils/tensor_utils.py (lines 21–25)
def pt2jax(array: torch.Tensor) -> chex.Array:
    """Convert a PyTorch tensor to a JAX array.

    The conversion is routed through NumPy: ``pt2np`` first turns the tensor
    into a host-side NumPy array, then ``np2jax`` wraps that as a JAX array.

    Args:
        array: The PyTorch tensor to convert.

    Returns:
        The JAX array returned by ``np2jax``.
    """
    host_array = pt2np(array)
    return np2jax(host_array)

pt2np(array)

Convert a PyTorch tensor to a NumPy array.

Source code in src/python/easydel/utils/tensor_utils.py (lines 7–11)
def pt2np(array: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, because ``Tensor.numpy()`` only accepts CPU tensors that do not
    require grad.

    ``torch.bfloat16`` has no NumPy counterpart, and ``Tensor.numpy()`` raises
    ``TypeError`` for it. Such tensors are therefore upcast to ``float32``
    before conversion. The upcast is lossless because every bfloat16 value is
    exactly representable in float32.

    Args:
        array: The tensor to convert.

    Returns:
        A NumPy array holding the tensor's values. For a CPU tensor of a
        NumPy-supported dtype, the result may share memory with the tensor.
    """
    tensor = array.detach().cpu()
    # NumPy cannot represent bfloat16; widen instead of letting .numpy() raise.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)
    return tensor.numpy()