Skip to content

bits.stochastic_rounding

Efficient stochastic rounding implementation.

random_centered_uniform(shape, key)

Generates uniformly distributed float32 numbers in the open interval (-0.5, 0.5), symmetric about zero.

Source code in `src/fjformer/bits/stochastic_rounding.py` (lines 23–44):
def random_centered_uniform(
        shape: tuple[int, ...], key: jax.Array
) -> jnp.ndarray:
    """Draw float32 samples spread uniformly and symmetrically around zero.

    The samples are built by bit manipulation rather than float arithmetic.
    A random 16-bit integer ``k`` is placed in the high bits of a float32
    mantissa whose exponent encodes 1.0. That yields ``1 + k / 2**16`` in
    [1, 2). Subtracting ``1.5 - 2**-17`` then recenters the value. Every
    sample therefore equals ``(k + 0.5) / 2**16 - 0.5``. The values lie in
    [-0.5 + 2**-17, 0.5 - 2**-17], strictly inside (-0.5, 0.5), and are
    symmetric about zero.

    Args:
      shape: Shape of the returned array.
      key: PRNG key forwarded to ``jax.random.bits``.

    Returns:
      A float32 array of the requested shape.
    """
    random_dtype = jnp.dtype('uint16')
    width = jnp.iinfo(random_dtype).bits           # 16 random bits per sample
    mantissa_width = jnp.finfo(jnp.float32).nmant  # 23 explicit mantissa bits

    raw = jax.random.bits(key, shape, random_dtype)

    # Shift the random bits to the top of the mantissa, then OR in the
    # sign/exponent bits of 1.0 so the reinterpreted float lies in [1, 2).
    one_pattern = jax.lax.bitcast_convert_type(jnp.float32(1), jnp.uint32)
    pattern = (raw.astype(jnp.uint32) << (mantissa_width - width)) | one_pattern
    assert pattern.dtype == jnp.uint32

    in_one_two = jax.lax.bitcast_convert_type(pattern, jnp.float32)

    # Offset by 1.5 minus half a step (2**-17), so the 2**16 grid points
    # sit symmetrically around zero instead of spanning [-0.5, 0.5).
    offset = 1.5 - 2 ** (-1 - width)
    return in_one_two - offset