Skip to content

bits.int_numerics

Numerics for int8, int4, binary and other integer types.

IntNumerics

Bases: QNumerics, PyTreeNode

Numerics for int8, int4, binary, etc.

Source code in src/fjformer/bits/int_numerics.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class IntNumerics(numerics.QNumerics, flax.struct.PyTreeNode):
    """Numerics for int8, int4, binary, etc.

    The forward pass optionally adds noise, clips to a symmetric bound
    derived from ``bits`` and finally rounds either to integers or to
    half-integers. The backward pass is a pass-through gradient that can
    optionally be zeroed wherever the input was outside the clip bound.
    """

    # Bit width; the outermost bucket edge is derived from 2 ** (bits - 1).
    bits: int
    # When set, half a bucket is given up and rounding goes to integers
    # (nearest-even); otherwise rounding goes to half-integers.
    preserve_zero: bool
    # Selects bucket center (True) vs. bucket edge (False) in
    # `abs_val_mapped_to`.
    preserve_max_val: bool
    # Whether `fwd` clips its input to the forward clip bound.
    clip: bool
    # Whether `vjp_bwd` masks the gradient outside the forward clip bound.
    clip_gradient: bool
    # Whether `fwd` rounds its input.
    round: bool
    # Optional callable invoked as noise_fn(shape, key); its result is added
    # to the input before clipping and rounding.
    noise_fn: Optional[stochastic_rounding.NoiseFn]
    # Value returned verbatim by `get_dtype`.
    dtype: Optional[Any] = None

    def get_edge_of_last_int_bucket(self):
        """Return the outer edge of the largest-magnitude bucket."""
        edge = 2.0 ** (self.bits - 1)
        # Preserving zero costs one bucket, i.e. half a bucket per side.
        return edge - 0.5 if self.preserve_zero else edge

    def get_center_of_last_int_bucket(self):
        """Return the center of the largest bucket (half a unit inside its edge)."""
        outer_edge = self.get_edge_of_last_int_bucket()
        return outer_edge - 0.5

    def abs_val_mapped_to(self):
        """Return the last bucket's center if `preserve_max_val`, else its edge."""
        if not self.preserve_max_val:
            return self.get_edge_of_last_int_bucket()
        return self.get_center_of_last_int_bucket()

    def _get_fwd_clip_bound(self):
        """Return the symmetric bound used for clipping in the forward pass."""
        # Without rounding, clipping to the bucket edge is enough.
        bound = self.get_edge_of_last_int_bucket()
        if not self.round:
            return bound
        # With rounding after the clip, values sitting exactly at the edge
        # must not be rounded into a bucket that does not exist. Any shrink
        # amount in (0.0, 1.0) would be correct here.
        return bound - 0.5

    def get_dtype(self):
        """Return the configured `dtype` field."""
        return self.dtype

    def fwd(self, x, context):
        """Forward pass: optional noise, optional clip, optional rounding.

        Args:
            x: Input array.
            context: Object whose `key` attribute is handed to `noise_fn`;
                only accessed when `noise_fn` is set.

        Returns:
            The transformed array.
        """
        assert self.bits <= 22, 'Too many bits, float32 has less precision.'

        # Optional noise.
        if self.noise_fn:
            assert context.key is not None, (
                'noise_fn is set, requestic stochastic rounding, but RNG was not '
                'passed in Context.key'
            )
            noise = self.noise_fn(x.shape, context.key)
            # Cast back so adding noise does not change the input dtype.
            x = (x + noise).astype(x.dtype)

        # Optional clip.
        if self.clip:
            bound = self._get_fwd_clip_bound()
            x = jnp.clip(x, -bound, bound)

        # Optional rounding.
        if not self.round:
            return x
        # TODO(lew): Have bucket centers at 2*k + 1, not at halves.
        if self.preserve_zero:
            return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
        # Not preserving zero: bucket centers sit at half-integers.
        return jnp.floor(x) + 0.5

    def vjp_fwd(self, x, context):
        """Run `fwd` and return `(output, residuals)` with `x` as the residual."""
        return self.fwd(x, context), (x,)

    def vjp_bwd(self, res, grad):
        """Backward pass: pass-through gradient, optionally masked by the clip.

        Args:
            res: Residuals produced by `vjp_fwd`, i.e. the tuple `(x,)`.
            grad: Incoming gradient.

        Returns:
            A pair `(gradient, None)`.
        """
        # Gradient of the clip function; boundary values get the full
        # gradient. With abs(max(x)) scaling x is always in the interior and
        # the clip gradient is always 1, so clip_gradient can be left False.
        # Other kinds of scaling may push x outside the bound (real
        # clipping); there it can be desirable to zero the gradient.
        if not self.clip_gradient:
            return (grad, None)
        (x,) = res
        bound = self._get_fwd_clip_bound()
        inside = (-bound <= x) * (x <= bound)
        masked = grad
        masked *= inside
        return (masked, None)

fwd(x, context)

Forward pass.

Source code in src/fjformer/bits/int_numerics.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def fwd(self, x, context):
    """Forward pass: optional noise, optional clip, optional rounding.

    Args:
        x: Input array.
        context: Object whose `key` attribute is handed to `noise_fn`;
            only accessed when `noise_fn` is set.

    Returns:
        The transformed array.
    """
    assert self.bits <= 22, 'Too many bits, float32 has less precision.'

    # Optional noise.
    if self.noise_fn:
        assert context.key is not None, (
            'noise_fn is set, requestic stochastic rounding, but RNG was not '
            'passed in Context.key'
        )
        noise = self.noise_fn(x.shape, context.key)
        # Cast back so adding noise does not change the input dtype.
        x = (x + noise).astype(x.dtype)

    # Optional clip to the symmetric forward bound.
    if self.clip:
        bound = self._get_fwd_clip_bound()
        x = jnp.clip(x, -bound, bound)

    # Optional rounding.
    if not self.round:
        return x
    # TODO(lew): Have bucket centers at 2*k + 1, not at halves.
    if self.preserve_zero:
        return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
    # Not preserving zero: bucket centers sit at half-integers.
    return jnp.floor(x) + 0.5