Skip to content

bits.calibration

Quantization calibration methods.

AbsMaxCalibration

Bases: Calibration

Simple max(abs(x)) calibration.

Source code in src/fjformer/bits/calibration.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@flax.struct.dataclass
class AbsMaxCalibration(Calibration):
    """Calibration whose bound is the largest absolute value of the input."""

    def get_bound(self, x, shared_axes) -> jnp.ndarray:
        """Return max(abs(x)) reduced over ``shared_axes``.

        Args:
            x: Array to compute the bound for.
            shared_axes: Axis or axes forwarded to ``jnp.max``; must not be
                ``None``.

        Returns:
            The absolute maximum of ``x`` over ``shared_axes``, with the
            reduced axes kept as size-1 dimensions. Entries that would be
            exactly zero are replaced by one.

        Raises:
            AssertionError: If ``shared_axes`` is ``None``.
        """
        assert shared_axes is not None, (
            'Perhaps you are using fake_quant and forgot to set them.'
        )

        # NOTE: If you want to clip, consider using clip and clip_gradient in
        # int_numerics.IntNumerics.
        magnitudes = jnp.abs(x)
        bound = jnp.max(magnitudes, axis=shared_axes, keepdims=True)
        # An all-zero group yields a zero bound; substitute one instead
        # (presumably so later scaling never divides by zero).
        return jnp.where(bound == 0.0, jnp.ones_like(bound), bound)

get_bound(x, shared_axes)

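Return the calibration bound: the maximum of abs(x) over shared_axes (reduced axes kept as size 1), with zero entries replaced by one.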

Source code in src/fjformer/bits/calibration.py
44
45
46
47
48
49
50
51
52
53
def get_bound(self, x, shared_axes) -> jnp.ndarray:
    """Return max(abs(x)) reduced over ``shared_axes``.

    Args:
        x: Array to compute the bound for.
        shared_axes: Axis or axes forwarded to ``jnp.max``; must not be
            ``None``.

    Returns:
        The absolute maximum of ``x`` over ``shared_axes``, with the reduced
        axes kept as size-1 dimensions. Entries that would be exactly zero
        are replaced by one.

    Raises:
        AssertionError: If ``shared_axes`` is ``None``.
    """
    assert shared_axes is not None, (
        'Perhaps you are using fake_quant and forgot to set them.'
    )

    # NOTE: If you want to clip, consider using clip and clip_gradient in
    # int_numerics.IntNumerics.
    magnitudes = jnp.abs(x)
    bound = jnp.max(magnitudes, axis=shared_axes, keepdims=True)
    # An all-zero group yields a zero bound; substitute one instead
    # (presumably so later scaling never divides by zero).
    return jnp.where(bound == 0.0, jnp.ones_like(bound), bound)

ConstantCalibration

Bases: Calibration

Source code in src/fjformer/bits/calibration.py
29
30
31
32
33
34
35
36
37
@flax.struct.dataclass
class ConstantCalibration(Calibration):
    """Calibration that always reports the same, user-supplied bound."""

    # Value returned by get_bound; asserted to be strictly positive there.
    bound: Union[jnp.ndarray, float]

    def get_bound(self, x, shared_axes) -> jnp.ndarray:
        """Return ``self.bound`` shaped to broadcast against ``x``.

        Args:
            x: Input whose number of dimensions sets the rank of the result.
            shared_axes: Ignored.

        Returns:
            ``self.bound`` as an array with one size-1 dimension per
            dimension of ``x``.

        Raises:
            AssertionError: If ``self.bound`` is not greater than zero.
        """
        del shared_axes  # Not needed for a constant bound.
        assert self.bound > 0, 'Bound should be positive.'
        constant = jnp.asarray(self.bound)
        all_ones_shape = (1,) * len(x.shape)
        return constant.reshape(all_ones_shape)

get_bound(x, shared_axes)

Return the constant bound, reshaped to one size-1 dimension per dimension of x; shared_axes is ignored.

Source code in src/fjformer/bits/calibration.py
33
34
35
36
37
def get_bound(self, x, shared_axes) -> jnp.ndarray:
    """Return ``self.bound`` shaped to broadcast against ``x``.

    Args:
        x: Input whose number of dimensions sets the rank of the result.
        shared_axes: Ignored.

    Returns:
        ``self.bound`` as an array with one size-1 dimension per dimension
        of ``x``.

    Raises:
        AssertionError: If ``self.bound`` is not greater than zero.
    """
    del shared_axes  # Not needed for a constant bound.
    assert self.bound > 0, 'Bound should be positive.'
    constant = jnp.asarray(self.bound)
    all_ones_shape = (1,) * len(x.shape)
    return constant.reshape(all_ones_shape)