Generates zero-centered uniform random numbers in the open interval (-0.5, 0.5).
Source code in src/fjformer/bits/stochastic_rounding.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def random_centered_uniform(
    shape: tuple[int, ...], key: jax.Array
) -> jnp.ndarray:
    """Generates zero-centered uniform float32 noise in the open interval (-0.5, 0.5).

    The samples are built by bit manipulation instead of
    ``jax.random.uniform``. 16 random bits are written into the most
    significant mantissa bits of a float32 whose exponent is that of 1.0,
    which gives a value on the grid ``1 + k * 2**-16`` (``k`` in
    ``[0, 2**16)``). That grid is then shifted down by ``1.5 - 2**-17``.

    Every output is therefore exactly one of the 2**16 values
    ``(2*k + 1 - 2**16) / 2**17``, i.e. the odd multiples of ``2**-17``
    strictly inside ``(-0.5, 0.5)``. The endpoints ``-0.5`` and ``0.5`` are
    never produced. The grid is symmetric about zero, so the distribution
    has mean exactly 0.

    Args:
        shape: Shape of the returned array.
        key: PRNG key forwarded to ``jax.random.bits``.

    Returns:
        A float32 array of the given ``shape``.
    """
    dtype = jnp.dtype('uint16')
    nbits = jnp.iinfo(dtype).bits  # number of random bits per sample (16)
    # Generate random bits.
    bits = jax.random.bits(key, shape, dtype)
    # Align the random bits with the top of the float32 mantissa:
    # float32 has `nmant` explicit mantissa bits, so shift left by nmant - nbits.
    nmant = jnp.finfo(jnp.float32).nmant
    r_bitpattern = jnp.uint32(bits) << (nmant - nbits)
    # OR in the sign/exponent bits of 1.0 so the pattern decodes to
    # 1 + k * 2**-nbits, a value in [1, 2).
    r_bitpattern = r_bitpattern | jnp.float32(1).view(jnp.uint32)
    assert r_bitpattern.dtype == jnp.uint32
    rand_floats = jax.lax.bitcast_convert_type(r_bitpattern, jnp.float32)
    # Subtracting 1.5 alone would give a grid from -0.5 up to 0.5 - 2**-nbits,
    # which is biased downward. The extra half-step `shift` makes it symmetric
    # about 0. Both operands lie in [1, 2), so the subtraction is exact in
    # float32.
    shift = 2 ** (-1 - nbits)
    centered = rand_floats - (1.5 - shift)
    return centered
|