bits.numerics

Base abstract class for all numerics.

QNumerics

Bases: PyTreeNode, ABC

Numerics for int8, int4, binary, etc.

Source code in src/fjformer/bits/numerics.py
import abc

import flax.struct


class QNumerics(flax.struct.PyTreeNode, abc.ABC):
    """Numerics for int8, int4, binary, etc."""

    # TODO(lew): Currently this is a part of API, only because it is used to set
    # it in test. Remove and leave only get_dtype(

    @abc.abstractmethod
    def get_dtype(self):
        """Returns the dtype used to store quantized values."""
        pass

    @abc.abstractmethod
    def fwd(self, x, context):
        """Forward pass."""
        pass

    @abc.abstractmethod
    def abs_val_mapped_to(self):
        """The value returned is the end of the quantization range.

        It can be the largest value the numerical format represents exactly
        (for int8, 127), or the edge of the last bucket (for int8, 127.5).
        """
        pass

    @abc.abstractmethod
    def vjp_fwd(self, x, context):
        """Forward rule of a custom VJP: returns the output and residuals."""
        pass

    @abc.abstractmethod
    def vjp_bwd(self, res, grad):
        """Backward rule of a custom VJP: maps residuals and grad to input grads."""
        pass
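To show how the five abstract methods fit together, here is a minimal sketch of a concrete signed-integer numerics. It is an illustration under stated assumptions, not fjformer's implementation: it mirrors the interface with a plain `abc.ABC` instead of subclassing `QNumerics` (so it runs without flax or JAX), it works on Python lists of floats, and the names `Numerics` and `SymmetricIntNumerics` are hypothetical. The backward rule is a straight-through estimator, which passes the gradient inside the clipping range and zeroes it outside.

```python
import abc


class Numerics(abc.ABC):
    """Hypothetical stdlib-only mirror of the QNumerics interface (no flax/JAX)."""

    @abc.abstractmethod
    def get_dtype(self): ...

    @abc.abstractmethod
    def fwd(self, x, context): ...

    @abc.abstractmethod
    def abs_val_mapped_to(self): ...

    @abc.abstractmethod
    def vjp_fwd(self, x, context): ...

    @abc.abstractmethod
    def vjp_bwd(self, res, grad): ...


class SymmetricIntNumerics(Numerics):
    """Hypothetical signed-integer numerics over Python lists of floats."""

    def __init__(self, bits: int = 8):
        self.bits = bits

    def get_dtype(self):
        # A string stands in for a real dtype such as jnp.int8.
        return f"int{self.bits}"

    def abs_val_mapped_to(self):
        # Largest exactly representable magnitude: 127 for int8, 7 for int4.
        return 2 ** (self.bits - 1) - 1

    def fwd(self, x, context=None):
        # Clip to [-edge, edge], then round to the nearest integer
        # (Python's round() sends exact halves to the even neighbour).
        edge = self.abs_val_mapped_to()
        return [round(min(max(v, -edge), edge)) for v in x]

    def vjp_fwd(self, x, context=None):
        # Return the primal output plus the input as the residual for vjp_bwd.
        return self.fwd(x, context), (x,)

    def vjp_bwd(self, res, grad):
        # Straight-through estimator: keep the gradient where x was inside the
        # clipping range, zero it outside. None is the (absent) context gradient.
        (x,) = res
        edge = self.abs_val_mapped_to()
        return [g if -edge <= v <= edge else 0.0 for v, g in zip(x, grad)], None
```

A real subclass would operate on JAX arrays and would typically be wrapped with `jax.custom_vjp`, so that `vjp_fwd` and `vjp_bwd` replace the gradient of rounding, which is zero almost everywhere.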

abs_val_mapped_to() abstractmethod

The value returned is the end of the quantization range.

It can be the largest value the numerical format represents exactly (for int8, 127), or the edge of the last bucket (for int8, 127.5).
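The choice between the two conventions changes the quantization scale. Below is a minimal sketch, assuming max-abs calibration in which the scale is max|x| divided by the range end. The helpers `range_end` and `quantize` are hypothetical and not part of fjformer.

```python
def range_end(bits: int, bucket_edge: bool) -> float:
    """Hypothetical: end of the signed-integer quantization range."""
    top = 2 ** (bits - 1) - 1  # largest exactly representable value, 127 for int8
    # Bucket-edge convention: anything below top + 0.5 still rounds to top.
    return top + 0.5 if bucket_edge else float(top)


def quantize(xs: list[float], bits: int = 8, bucket_edge: bool = False) -> list[int]:
    """Scale so that max|x| maps to the range end, then round and clip."""
    top = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in xs) / range_end(bits, bucket_edge)
    # Clip because max|x| / scale can round just past top under the edge convention.
    return [max(-top, min(top, round(v / scale))) for v in xs]


print(quantize([1.0, 0.9, -0.3]))                    # [127, 114, -38]
print(quantize([1.0, 0.9, -0.3], bucket_edge=True))  # [127, 115, -38]
```

With a range end of 127, scaled values never exceed 127, so the top bucket [126.5, 127.5) is only half used. With 127.5 the top bucket is fully covered and the scale is slightly smaller, which is why 0.9 lands on a different code.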


fwd(x, context) abstractmethod

Forward pass.

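In a fake-quantization round trip, the forward pass is the step that maps already-scaled values onto integer codes. A minimal sketch, assuming the caller divides by a max-abs scale before calling `fwd` and multiplies by it afterwards; `int8_fwd` and `fake_quant` are hypothetical names, not fjformer functions.

```python
from typing import Callable, Optional


def int8_fwd(x: list[float], context: Optional[object] = None) -> list[int]:
    """Hypothetical int8 forward pass: clip to [-127, 127] and round."""
    return [round(min(max(v, -127.0), 127.0)) for v in x]


def fake_quant(
    x: list[float], fwd: Callable[[list[float], Optional[object]], list[int]]
) -> tuple[list[float], list[int], float]:
    """Scale into the integer grid, run fwd, and scale back."""
    scale = max(abs(v) for v in x) / 127.0  # max-abs calibration
    codes = fwd([v / scale for v in x], None)  # integer codes in [-127, 127]
    return [c * scale for c in codes], codes, scale


deq, codes, scale = fake_quant([2.0, -0.6, 0.5], int8_fwd)
print(codes)  # [127, -38, 32]
```

For values inside the range, each dequantized element differs from its input by at most half a scale step, which gives a simple sanity check for a numerics implementation.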