Skip to content

bits.q_flax

Flax layer for AQT injection.

Freezer

Bases: Module

Identity function that can freeze its input.

By default it is an identity function that saves the input in a variable. In 'use_frozen=True' mode, it ignores the input and returns the frozen value. It is useful for implementing 'constant folding' and for putting quantized weights and scales in the checkpoint for serving.

Source code in src/fjformer/bits/q_flax.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class Freezer(nn.Module):
    """Pass-through module that can store and replay a quantized tensor.

    What a call does depends on ``quant_mode`` and on whether ``inputs`` is
    None:

    * Getter call (``inputs is None``): TRAIN and CONVERT return None; SERVE
      returns a ``QTensor`` built from the stored 'value' and 'scale'
      variables of ``quant_collection``.
    * Setter call (``inputs`` given): CONVERT writes ``inputs.qvalue`` and
      ``inputs.qvalue_scale_t`` into those variables; TRAIN and SERVE do
      nothing. The setter always returns None.

    It is useful to implement 'constant folding' and put quantized weights and
    scales in the checkpoint for serving.
    """

    quant_collection: str
    quant_mode: QuantMode
    q_shape: Iterable[int]
    q_init: nn.initializers.Initializer
    s_shape: Iterable[int]
    s_init: nn.initializers.Initializer

    @nn.compact
    def __call__(
            self, inputs: Optional[q_dot_general.QTensor]
    ) -> Optional[q_dot_general.QTensor]:
        mode = self.quant_mode

        def frozen_variables():
            # We could have created one self.variable whose value is a QTensor,
            # but this would complicate the init function, which could potentially
            # be used by adding metadata such as sharding axises, etc.
            stored_q = self.variable(
                self.quant_collection, 'value', self.q_init, self.q_shape
            )
            stored_s = self.variable(
                self.quant_collection, 'scale', self.s_init, self.s_shape
            )
            return stored_q, stored_s

        if inputs is None:
            # Getter mode: only SERVE has something to hand back.
            if mode == QuantMode.TRAIN or mode == QuantMode.CONVERT:
                return None
            if mode == QuantMode.SERVE:
                stored_q, stored_s = frozen_variables()
                return q_dot_general.QTensor(stored_q.value, stored_s.value)
            assert False, 'Unknown quant mode.'
            return None

        # Setter mode: only CONVERT records the incoming tensor.
        if mode == QuantMode.CONVERT:
            stored_q, stored_s = frozen_variables()
            stored_q.value = inputs.qvalue
            stored_s.value = inputs.qvalue_scale_t
        elif not (mode == QuantMode.TRAIN or mode == QuantMode.SERVE):
            # TODO(lew): Optionally compare stored and served value (SERVE mode).
            assert False, 'Unknown quant mode.'
        return None

QDotGeneral

Bases: Module

A layer that can be injected into flax.nn.Dense, etc.

Source code in src/fjformer/bits/q_flax.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class QDotGeneral(nn.Module):
    """A layer that can be injected into flax.nn.Dense, etc.

    Calling the module builds a dot_general from ``cfg`` (through
    ``q_dot_general.make_dot_general``) and applies it to the operands. When
    ``cfg`` is not None, a ``Freezer`` is installed as the forward-pass
    ``preprocess`` hook of each operand.
    """

    cfg: Optional[config.DotGeneral] = None
    prng_name: Optional[str] = 'params'

    # TODO(lew): split out separate class for each side.
    lhs_quant_mode: QuantMode = QuantMode.TRAIN
    lhs_init: nn.initializers.Initializer = jnp.zeros
    lhs_scale_init: nn.initializers.Initializer = jnp.zeros
    lhs_var_name: str = 'qlhs'

    rhs_quant_mode: QuantMode = QuantMode.TRAIN
    rhs_init: nn.initializers.Initializer = jnp.zeros
    rhs_scale_init: nn.initializers.Initializer = jnp.zeros
    rhs_var_name: str = 'qrhs'

    # If you want use 'params' make sure that there is another mechanism to hide
    # these variables from the optimizer.
    quant_collection: str = 'aqt'

    def make_aqt_dg(
            self,
            lhs_shape,
            rhs_shape,
            dimension_numbers: tuple[Iterable[int], Iterable[int]],
    ):
        """Return a dot_general callable bound to a fresh ``Context``.

        :param lhs_shape: shape of the left operand
        :param rhs_shape: shape of the right operand
        :param dimension_numbers: pair whose first element holds the
            contracting axes of both operands
        :return: ``functools.partial`` of the configured dot_general with
            ``context`` already supplied
        """
        contracting, _ = dimension_numbers

        # Scale shapes equal the operand shapes with contracted axes set to 1.
        lhs_unit_shape = list(lhs_shape)
        rhs_unit_shape = list(rhs_shape)
        for lhs_axis, rhs_axis in zip(*contracting):
            lhs_unit_shape[lhs_axis] = 1
            rhs_unit_shape[rhs_axis] = 1

        # Run zero placeholders through the scale transposes only to learn the
        # resulting shapes.
        transposed_lhs = q_dot_general._lhs_scale_transpose(  # pylint: disable=protected-access
            jnp.zeros(lhs_unit_shape), dimension_numbers, lhs_shape, rhs_shape
        )
        assert transposed_lhs is not None
        transposed_rhs = q_dot_general._rhs_scale_transpose(  # pylint: disable=protected-access
            jnp.zeros(rhs_unit_shape), dimension_numbers, lhs_shape, rhs_shape
        )
        assert transposed_rhs is not None

        dg_cfg = copy.deepcopy(self.cfg)
        if dg_cfg is not None:
            msg = 'The only function that is setting preprocess can be QQuantized.'
            assert dg_cfg.fwd.rhs.preprocess is None, msg
            assert dg_cfg.fwd.lhs.preprocess is None, msg
            dg_cfg.fwd.lhs.preprocess = Freezer(
                name=self.lhs_var_name,
                quant_mode=self.lhs_quant_mode,
                q_shape=lhs_shape,
                q_init=self.lhs_init,
                s_shape=transposed_lhs.shape,
                s_init=self.lhs_scale_init,
                quant_collection=self.quant_collection,
            )
            dg_cfg.fwd.rhs.preprocess = Freezer(
                name=self.rhs_var_name,
                quant_mode=self.rhs_quant_mode,
                q_shape=rhs_shape,
                q_init=self.rhs_init,
                s_shape=transposed_rhs.shape,
                s_init=self.rhs_scale_init,
                quant_collection=self.quant_collection,
            )

        # A PRNG key is drawn only when a stream name is configured.
        if self.prng_name is None:
            key = None
        else:
            key = self.make_rng(self.prng_name)
        context = q_dot_general.Context(key=key, train_step=None)
        return functools.partial(
            q_dot_general.make_dot_general(dg_cfg), context=context
        )

    @nn.compact
    def __call__(
            self,
            lhs,
            rhs,
            dimension_numbers,
            precision,
            preferred_element_type=None,
    ):
        """Apply the configured dot_general to ``lhs`` and ``rhs``."""
        dot_general_fn = self.make_aqt_dg(lhs.shape, rhs.shape, dimension_numbers)
        result = dot_general_fn(
            lhs,
            rhs,
            dimension_numbers,
            precision,
            preferred_element_type=preferred_element_type,
        )
        return result

QEinsum

Bases: PyTreeNode

Quantized Einsum class for model injection.

Source code in src/fjformer/bits/q_flax.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
class QEinsum(flax.struct.PyTreeNode):
    """Quantized Einsum class for model injection.

    Calling an instance evaluates the einsum ``eqn`` on two operands, using a
    ``QDotGeneral`` built from this object's fields as the underlying
    dot_general.
    """

    cfg: Optional[config.DotGeneral] = None
    prng_name: Optional[str] = 'params'

    # TODO(lew): split out separate class for each side.
    lhs_quant_mode: QuantMode = QuantMode.TRAIN
    lhs_init: nn.initializers.Initializer = jnp.zeros
    lhs_scale_init: nn.initializers.Initializer = jnp.zeros
    lhs_var_name: str = 'qlhs'

    rhs_quant_mode: QuantMode = QuantMode.TRAIN
    rhs_init: nn.initializers.Initializer = jnp.zeros
    rhs_scale_init: nn.initializers.Initializer = jnp.zeros
    rhs_var_name: str = 'qrhs'

    # If you want use 'params' make sure that there is another mechanism to hide
    # these variables from the optimizer.
    quant_collection: str = 'aqt'

    def __call__(self, eqn, lhs_g, rhs_g):
        """Evaluate ``eqn`` on ``lhs_g`` and ``rhs_g`` with a quantized dot_general.

        :param eqn: einsum equation string
        :param lhs_g: first einsum operand
        :param rhs_g: second einsum operand
        :return: result of the einsum
        """

        def einsum(first, second, dg=jax.lax.dot_general):
            operands, contractions = lax_numpy._default_poly_einsum_handler(  # pylint: disable=protected-access
                eqn, first, second, einsum_call=True, use_blas=True, optimize='optimal'
            )
            contractions = tuple(
                (a, frozenset(b), c) for a, b, c, *_ in contractions
            )
            named_einsum = jax.named_call(lax_numpy._einsum, name=eqn)  # pylint: disable=protected-access
            return named_einsum(
                operands,
                contractions,
                precision=None,
                preferred_element_type=None,
                _dot_general=dg,
            )

        # Trace with the default dot_general to see whether einsum passes the
        # operands to it in the given order or swapped.
        traced = jax.make_jaxpr(einsum)(lhs_g, rhs_g)
        [first_eqn_lhs, first_eqn_rhs] = traced.eqns[0].invars
        [arg_lhs, arg_rhs] = traced.jaxpr.invars
        same_order = first_eqn_lhs == arg_lhs and first_eqn_rhs == arg_rhs
        swapped = first_eqn_lhs == arg_rhs and first_eqn_rhs == arg_lhs
        assert same_order != swapped

        cfg = copy.deepcopy(self.cfg)

        # Per-side settings, bundled so a swap is a single exchange.
        lhs_side = dict(
            quant_mode=self.lhs_quant_mode,
            init=self.lhs_init,
            scale_init=self.lhs_scale_init,
            var_name=self.lhs_var_name,
        )
        rhs_side = dict(
            quant_mode=self.rhs_quant_mode,
            init=self.rhs_init,
            scale_init=self.rhs_scale_init,
            var_name=self.rhs_var_name,
        )

        if swapped:
            if cfg is not None:
                cfg.fwd.lhs, cfg.fwd.rhs = cfg.fwd.rhs, cfg.fwd.lhs
                cfg.dlhs, cfg.drhs = cfg.drhs, cfg.dlhs
            lhs_side, rhs_side = rhs_side, lhs_side

        quantized_dg = QDotGeneral(
            cfg=cfg,
            prng_name=self.prng_name,
            lhs_quant_mode=lhs_side['quant_mode'],
            lhs_init=lhs_side['init'],
            lhs_scale_init=lhs_side['scale_init'],
            lhs_var_name=lhs_side['var_name'],
            rhs_quant_mode=rhs_side['quant_mode'],
            rhs_init=rhs_side['init'],
            rhs_scale_init=rhs_side['scale_init'],
            rhs_var_name=rhs_side['var_name'],
            quant_collection=self.quant_collection,
        )
        return einsum(lhs_g, rhs_g, quantized_dg)

config_v4(*, fwd_bits=8, dlhs_bits=8, drhs_bits=None, use_dummy_static_bound=False, rng_type='jax.uniform', dlhs_local_aqt=None, drhs_local_aqt=None, fwd_accumulator_dtype=jnp.int32, dlhs_accumulator_dtype=jnp.int32, drhs_accumulator_dtype=None)

The config_v4 function is a helper function that creates a DotGeneral config object. It takes the following arguments: - fwd_bits: The number of bits to use for forward pass quantization. If None, no quantization will be used. Defaults to 8 bits. - dlhs_bits: The number of bits to use for left-hand-side gradient quantization (i.e., the weights). If None, no quantization will be used. Defaults to 8 bits. - drhs_bits: The number of bits to use for right-hand-side gradient quantization. If None, no quantization will be used. Defaults to None.

Parameters:

Name Type Description Default
*

Marks every following parameter as keyword-only; it is not itself an argument

required
fwd_bits Optional[int]

Optional[int]: Set the number of bits for the forward pass

8
dlhs_bits Optional[int]

Optional[int]: Set the number of bits used for quantization

8
drhs_bits Optional[int]

Optional[int]: Set the number of bits for the right-hand-side gradient (drhs) pass; None means no quantization

None
use_dummy_static_bound bool

bool: Set the static bound to 1

False
rng_type str

str: Set the type of random number generator

'jax.uniform'
dlhs_local_aqt Optional[LocalQ]

Optional[config.LocalQ]: Set the local quantization parameters for the lhs gradient

None
drhs_local_aqt Optional[LocalQ]

Optional[config.LocalQ]: Set the local quantization parameters for the drhs tensor

None
fwd_accumulator_dtype ...

...: Set the accumulator dtype for forward pass

int32
dlhs_accumulator_dtype ...

...: Determine the dtype of the accumulator used in the dlhs (left-hand-side gradient) dot_general

int32
drhs_accumulator_dtype ...

...: Determine the dtype of the accumulator used in q_dot_general

None

Determine the number of bits used for quantization

required

Returns:

Type Description
DotGeneral

A config

Source code in src/fjformer/bits/q_flax.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def config_v4(
        *,
        fwd_bits: Optional[int] = 8,
        dlhs_bits: Optional[int] = 8,
        drhs_bits: Optional[int] = None,
        # The dummy static bound flag is for performance benchmarking.
        use_dummy_static_bound: bool = False,
        rng_type: str = 'jax.uniform',  # 'custom-1'
        dlhs_local_aqt: Optional[config.LocalQ] = None,
        drhs_local_aqt: Optional[config.LocalQ] = None,
        fwd_accumulator_dtype: ... = jnp.int32,
        dlhs_accumulator_dtype: ... = jnp.int32,
        drhs_accumulator_dtype: ... = None,
) -> config.DotGeneral:
    """
    Build a ``config.DotGeneral`` for the forward pass and both gradient passes.

    All parameters are keyword-only. Within each pass the same bit width is
    used for both operands; a bit width of None selects ``NoNumerics`` for that
    pass instead of ``IntNumerics``.

    :param fwd_bits: bit width for the forward dot_general, or None
    :param dlhs_bits: bit width for the dlhs (left-hand-side gradient)
        dot_general, or None
    :param drhs_bits: bit width for the drhs (right-hand-side gradient)
        dot_general, or None
    :param use_dummy_static_bound: when True, ``config.set_static_bound`` is
        called with a bound of 1.0 (intended for performance benchmarking)
    :param rng_type: passed as ``implementation`` to
        ``config.set_stochastic_rounding``
    :param dlhs_local_aqt: ``local_aqt`` setting for the dlhs dot_general
    :param drhs_local_aqt: ``local_aqt`` setting for the drhs dot_general
    :param fwd_accumulator_dtype: forwarded to ``config.set_accumulator_dtype``
        as ``fwd_dtype``
    :param dlhs_accumulator_dtype: forwarded to
        ``config.set_accumulator_dtype`` as ``dlhs_dtype``
    :param drhs_accumulator_dtype: forwarded to
        ``config.set_accumulator_dtype`` as ``drhs_dtype``
    :return: the assembled ``config.DotGeneral``
    """

    def tensor_config(bits: Optional[int]) -> config.Tensor:
        assert bits is None or bits >= 2, 'Need at least 2 bits.'
        if bits is not None:
            # Widths in [2, 8] are stored as int8; other widths leave the dtype unset.
            if 2 <= bits <= 8:
                storage_dtype = jnp.int8
            else:
                storage_dtype = None
            numerics = int_numerics.IntNumerics(
                bits=bits,
                preserve_zero=True,
                preserve_max_val=False,
                clip=True,
                round=True,
                noise_fn=None,
                clip_gradient=False,  # Can be False when using abs-max scaling.
                dtype=storage_dtype,
            )
        else:
            numerics = no_numerics.NoNumerics()

        return config.Tensor(
            numerics=numerics,
            calib_shared_axes=None,
            scale_stop_grad=True,
            calibration=calibration.AbsMaxCalibration(),
            po2_scale=False,
            use_fake_quant=False,
            use_fwd_quant=None,
            preprocess=None,
        )

    def dg_raw_config(lhs_bits, rhs_bits, local_aqt=None) -> config.DotGeneralRaw:
        lhs_tensor = tensor_config(lhs_bits)
        rhs_tensor = tensor_config(rhs_bits)
        both_at_most_8_bits = (
            lhs_bits is not None
            and rhs_bits is not None
            and lhs_bits <= 8
            and rhs_bits <= 8
        )
        # None determines the dtype on the fly in q_dot_general
        accumulator_dtype = jnp.int32 if both_at_most_8_bits else None
        return config.DotGeneralRaw(
            lhs=lhs_tensor,
            rhs=rhs_tensor,
            dg_accumulator_dtype=accumulator_dtype,
            local_aqt=local_aqt,
        )

    dot_cfg = config.DotGeneral(
        fwd=dg_raw_config(fwd_bits, fwd_bits),
        dlhs=dg_raw_config(dlhs_bits, dlhs_bits, local_aqt=dlhs_local_aqt),
        drhs=dg_raw_config(drhs_bits, drhs_bits, local_aqt=drhs_local_aqt),
    )

    dot_cfg.dlhs.rhs.use_fwd_quant = False
    dot_cfg.drhs.rhs.use_fwd_quant = False

    # Typically we have (but I don't know if it is guaranteed):
    # - vjp_lhs_stochastic_rounding is referring to the gradient and
    # - vjp_rhs_stochastic_rounding is referring to the activations/weights.
    config.set_stochastic_rounding(
        dot_cfg,
        vjp_lhs_stochastic_rounding=True,
        vjp_rhs_stochastic_rounding=False,
        implementation=rng_type,
    )

    if use_dummy_static_bound:
        config.set_static_bound(dot_cfg, 1.0)

    config.set_accumulator_dtype(
        dot_cfg,
        fwd_dtype=fwd_accumulator_dtype,
        dlhs_dtype=dlhs_accumulator_dtype,
        drhs_dtype=drhs_accumulator_dtype,
    )

    return dot_cfg