bits.config

Configuration dataclasses.

DotGeneral dataclass

Configuration of quantization of dot_general and its gradients.

Source code in src/fjformer/bits/config.py
@dataclasses.dataclass(slots=True)
class DotGeneral:
    """Configuration of quantization of dot_general and its gradients."""

    fwd: DotGeneralRaw
    dlhs: DotGeneralRaw
    drhs: DotGeneralRaw

    @classmethod
    def make(cls, *args, **kwargs) -> 'DotGeneral':
        return dot_general_make(*args, **kwargs)
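
The three fields line up with the three matrix products of a differentiable matmul: the forward product and one product per operand gradient. The pure-Python sketch below shows those products, assuming that dlhs and drhs denote the gradients with respect to lhs and rhs (the names suggest this, but the source does not spell it out). The helpers matmul and transpose are illustrative and not part of the library.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


lhs = [[1.0, 2.0], [3.0, 4.0]]
rhs = [[5.0, 6.0], [7.0, 8.0]]

# Forward product: the one the fwd config would describe.
out = matmul(lhs, rhs)

# Upstream gradient d(loss)/d(out); the identity keeps the numbers easy to read.
grad_out = [[1.0, 0.0], [0.0, 1.0]]

# Gradient with respect to lhs: grad_out @ rhs^T (assumed to be what dlhs covers).
dlhs = matmul(grad_out, transpose(rhs))

# Gradient with respect to rhs: lhs^T @ grad_out (assumed to be what drhs covers).
drhs = matmul(transpose(lhs), grad_out)
```

Each of the three products has its own left and right operand, which is why every field holds a full DotGeneralRaw with separate lhs and rhs settings.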

DotGeneralRaw dataclass

Configuration of quantization of one dot_general without gradient.

Source code in src/fjformer/bits/config.py
@dataclasses.dataclass(slots=True)
class DotGeneralRaw:
    """Configuration of quantization of one dot_general without gradient."""

    lhs: Tensor
    rhs: Tensor
    dg_accumulator_dtype: Optional[DType]
    local_aqt: Optional[LocalQ]

    @classmethod
    def make(cls, *args, **kwargs) -> 'DotGeneralRaw':
        """
        Create a DotGeneralRaw by delegating to dot_general_raw_make.

        :param args: Positional arguments forwarded to dot_general_raw_make
        :param kwargs: Keyword arguments forwarded to dot_general_raw_make
        :return: A DotGeneralRaw object
        """
        return dot_general_raw_make(*args, **kwargs)

    @classmethod
    def make_conv_general_dilated(cls, *args, **kwargs) -> 'DotGeneralRaw':
        """
        Create a DotGeneralRaw for a convolution by delegating to
        conv_general_dilated_make.

        :param args: Positional arguments forwarded to conv_general_dilated_make
        :param kwargs: Keyword arguments forwarded to conv_general_dilated_make
        :return: A DotGeneralRaw object
        """
        return conv_general_dilated_make(*args, **kwargs)

make(*args, **kwargs) classmethod

Factory method that creates a DotGeneralRaw instance by forwarding all arguments to dot_general_raw_make.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cls | | The class itself, passed implicitly because this is a classmethod | required |
| args | | Positional arguments forwarded to dot_general_raw_make | () |
| kwargs | | Keyword arguments forwarded to dot_general_raw_make | {} |

Returns:

| Type | Description |
| --- | --- |
| DotGeneralRaw | A DotGeneralRaw object |

make_conv_general_dilated(*args, **kwargs) classmethod

Factory method that creates a DotGeneralRaw instance for a convolution by forwarding all arguments to conv_general_dilated_make.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cls | | The class itself, passed implicitly because this is a classmethod | required |
| args | | Positional arguments forwarded to conv_general_dilated_make | () |
| kwargs | | Keyword arguments forwarded to conv_general_dilated_make | {} |

Returns:

| Type | Description |
| --- | --- |
| DotGeneralRaw | A DotGeneralRaw object |

Tensor dataclass

Configuration of quantization of one tensor or one side of tensor op.

Source code in src/fjformer/bits/config.py
@dataclasses.dataclass(slots=True)
class Tensor:
    """Configuration of quantization of one tensor or one side of tensor op."""

    numerics: numerics.QNumerics
    calib_shared_axes: Optional[list[int]]
    scale_stop_grad: bool
    # noise+clip+round
    # We apply gradient of clip_and_round in bwd pass.
    calibration: calibration.Calibration
    # Round up the calibration to power of 2 (po2).
    po2_scale: bool
    use_fake_quant: bool
    # Controls which value of the input tensor should be used.
    # Setting it to True without quantizing the fwd pass will assert-fail.
    use_fwd_quant: Optional[bool]
    # Operations for retrieving or storing quantized tensors and their scales
    # TODO(yichizh): Factor out auxiliary dataclasses into a separate file.
    # The following dtype Any should be q_dot_general.QTensor but that triggers
    # recursive importing
    preprocess: Optional[Callable[[Optional[Any]], Optional[Any]]]

    @classmethod
    def make(cls, *args, **kwargs) -> 'Tensor':
        return tensor_make(*args, **kwargs)

config_v3(*, fwd_bits=8, dlhs_bits=8, drhs_bits=None, use_dummy_static_bound=False, rng_type='jax.uniform', dlhs_local_aqt=None, drhs_local_aqt=None, fwd_accumulator_dtype=jnp.int32, dlhs_accumulator_dtype=jnp.int32, drhs_accumulator_dtype=None)

The config_v3 function builds a DotGeneral configuration with independently chosen bit widths for the forward dot_general (fwd) and for each gradient dot_general (dlhs and drhs). All parameters are keyword-only. It sets use_fwd_quant to False on the rhs of both gradient configs, enables stochastic rounding on the lhs of the gradient configs only, and then applies the requested accumulator dtypes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fwd_bits | Optional[int] | Number of bits for both operands of the forward dot_general; None selects NoNumerics | 8 |
| dlhs_bits | Optional[int] | Number of bits for both operands of the dlhs dot_general | 8 |
| drhs_bits | Optional[int] | Number of bits for both operands of the drhs dot_general | None |
| use_dummy_static_bound | bool | If True, call set_static_bound(cfg, 1.0) | False |
| rng_type | str | Noise implementation for stochastic rounding: 'jax.uniform' or 'custom-1' | 'jax.uniform' |
| dlhs_local_aqt | Optional[LocalQ] | local_aqt setting for the dlhs dot_general | None |
| drhs_local_aqt | Optional[LocalQ] | local_aqt setting for the drhs dot_general | None |
| fwd_accumulator_dtype | ... | Accumulator dtype for the forward dot_general | jnp.int32 |
| dlhs_accumulator_dtype | ... | Accumulator dtype for the dlhs dot_general | jnp.int32 |
| drhs_accumulator_dtype | ... | Accumulator dtype for the drhs dot_general | None |

Returns:

| Type | Description |
| --- | --- |
| DotGeneral | A DotGeneral object |

Source code in src/fjformer/bits/config.py
def config_v3(
        *,
        fwd_bits: Optional[int] = 8,
        dlhs_bits: Optional[int] = 8,
        drhs_bits: Optional[int] = None,
        use_dummy_static_bound: bool = False,
        rng_type: str = 'jax.uniform',  # 'custom-1'
        dlhs_local_aqt: Optional[LocalQ] = None,
        drhs_local_aqt: Optional[LocalQ] = None,
        fwd_accumulator_dtype: ... = jnp.int32,
        dlhs_accumulator_dtype: ... = jnp.int32,
        drhs_accumulator_dtype: ... = None,
) -> DotGeneral:
    """
    Build a DotGeneral configuration with separate bit widths for the forward
    dot_general and for each gradient dot_general. All parameters are keyword-only.

    :param fwd_bits: Number of bits for both operands of the forward dot_general
    :param dlhs_bits: Number of bits for both operands of the dlhs dot_general
    :param drhs_bits: Number of bits for both operands of the drhs dot_general
    :param use_dummy_static_bound: If True, call set_static_bound(cfg, 1.0)
    :param rng_type: Noise implementation for stochastic rounding ('jax.uniform' or 'custom-1')
    :param dlhs_local_aqt: local_aqt setting for the dlhs dot_general
    :param drhs_local_aqt: local_aqt setting for the drhs dot_general
    :param fwd_accumulator_dtype: Accumulator dtype for the forward dot_general
    :param dlhs_accumulator_dtype: Accumulator dtype for the dlhs dot_general
    :param drhs_accumulator_dtype: Accumulator dtype for the drhs dot_general
    :return: A DotGeneral object
    """
    fwd = dot_general_raw_make(fwd_bits, fwd_bits)
    dlhs = dot_general_raw_make(dlhs_bits, dlhs_bits, local_aqt=dlhs_local_aqt)
    drhs = dot_general_raw_make(drhs_bits, drhs_bits, local_aqt=drhs_local_aqt)
    cfg = DotGeneral(fwd=fwd, dlhs=dlhs, drhs=drhs)

    cfg.dlhs.rhs.use_fwd_quant = False
    cfg.drhs.rhs.use_fwd_quant = False

    # Typically we have (but I don't know if it is guaranteed):
    # - vjp_lhs_stochastic_rounding is referring to the gradient and
    # - vjp_rhs_stochastic_rounding is referring to the activations/weights.
    set_stochastic_rounding(
        cfg,
        vjp_lhs_stochastic_rounding=True,
        vjp_rhs_stochastic_rounding=False,
        implementation=rng_type,
    )

    if use_dummy_static_bound:
        set_static_bound(cfg, 1.0)

    set_accumulator_dtype(
        cfg,
        fwd_dtype=fwd_accumulator_dtype,
        dlhs_dtype=dlhs_accumulator_dtype,
        drhs_dtype=drhs_accumulator_dtype,
    )
    return cfg

conv_general_dilated_make(spatial_dimensions=2, lhs_bits=None, rhs_bits=None)

Create a quantization config for conv_general_dilated.

Source code in src/fjformer/bits/config.py
def conv_general_dilated_make(
        spatial_dimensions=2,
        lhs_bits: Optional[int] = None,
        rhs_bits: Optional[int] = None,
) -> 'DotGeneralRaw':
    """Create quantization config conv_general_dilated."""
    config = dot_general_raw_make(lhs_bits, rhs_bits)
    # Hardcoding flax assumptions.
    if config.lhs:
        config.lhs.calib_shared_axes = list(range(1, spatial_dimensions + 2))
    if config.rhs:
        config.rhs.calib_shared_axes = list(range(0, spatial_dimensions + 2 - 1))
    return config
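
The axis arithmetic above can be checked in isolation. The sketch below reproduces only the two range expressions; the helper name conv_calib_shared_axes is hypothetical. The reading of the axes in the comments assumes the Flax convention of batch-first inputs and output-features-last kernels, which the "Hardcoding flax assumptions" comment hints at but does not state.

```python
from typing import List, Tuple


def conv_calib_shared_axes(spatial_dimensions: int = 2) -> Tuple[List[int], List[int]]:
    """Mirror the two range expressions of conv_general_dilated_make."""
    # lhs: every axis except axis 0 (the batch axis under the assumed layout).
    lhs_axes = list(range(1, spatial_dimensions + 2))
    # rhs: every axis except the last (output features under the assumed layout).
    rhs_axes = list(range(0, spatial_dimensions + 2 - 1))
    return lhs_axes, rhs_axes


axes_2d = conv_calib_shared_axes(2)  # ([1, 2, 3], [0, 1, 2])
axes_1d = conv_calib_shared_axes(1)  # ([1, 2], [0, 1])
```

Under that assumed layout, a 2D convolution would share calibration across height, width, and channels of the input, and across height, width, and input channels of the kernel.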

dot_general_make(lhs_bits=None, rhs_bits=None, bwd_bits=None, use_fwd_quant=True, dlhs_local_aqt=None, drhs_local_aqt=None)

Create quantization configs for input matrices to a matmul.

Source code in src/fjformer/bits/config.py
def dot_general_make(
        lhs_bits: Optional[int] = None,
        rhs_bits: Optional[int] = None,
        bwd_bits: Optional[int] = None,
        use_fwd_quant: bool = True,
        dlhs_local_aqt=None,
        drhs_local_aqt=None,
) -> 'DotGeneral':
    """Create quantization configs for input matrices to a matmul."""
    fwd = dot_general_raw_make(lhs_bits, rhs_bits)
    dlhs = dot_general_raw_make(bwd_bits, bwd_bits, local_aqt=dlhs_local_aqt)
    drhs = dot_general_raw_make(bwd_bits, bwd_bits, local_aqt=drhs_local_aqt)
    cfg = DotGeneral(fwd=fwd, dlhs=dlhs, drhs=drhs)

    # Surprising: lhs quantization determines what drhs can do.
    if lhs_bits is not None:
        # Only rhs is accepting MultiTensor.
        cfg.drhs.rhs.use_fwd_quant = use_fwd_quant
    if rhs_bits is not None:
        cfg.dlhs.rhs.use_fwd_quant = use_fwd_quant
    return cfg
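
The cross-wiring at the end of dot_general_make is easy to misread: lhs_bits controls a flag on the drhs config, and rhs_bits controls a flag on the dlhs config. The sketch below isolates that logic with a plain dictionary instead of the real config objects; use_fwd_quant_flags is a hypothetical helper, and None stands for the value that tensor_make leaves in place.

```python
from typing import Dict, Optional


def use_fwd_quant_flags(
    lhs_bits: Optional[int],
    rhs_bits: Optional[int],
    use_fwd_quant: bool = True,
) -> Dict[str, Optional[bool]]:
    """Return the use_fwd_quant values that dot_general_make would assign."""
    # tensor_make initialises use_fwd_quant to None on every tensor config.
    flags: Dict[str, Optional[bool]] = {"drhs.rhs": None, "dlhs.rhs": None}
    if lhs_bits is not None:
        # A quantized forward lhs enables the flag on the rhs of drhs.
        flags["drhs.rhs"] = use_fwd_quant
    if rhs_bits is not None:
        # A quantized forward rhs enables the flag on the rhs of dlhs.
        flags["dlhs.rhs"] = use_fwd_quant
    return flags


only_lhs_quantized = use_fwd_quant_flags(lhs_bits=8, rhs_bits=None)
both_quantized = use_fwd_quant_flags(lhs_bits=8, rhs_bits=8, use_fwd_quant=False)
```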

dot_general_raw_make(lhs_bits=None, rhs_bits=None, local_aqt=None)

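The dot_general_raw_make function creates a DotGeneralRaw object. It builds one Tensor config per operand with tensor_make and sets dg_accumulator_dtype to jnp.int32 when both bit widths are between 2 and 8 inclusive, and to None otherwise.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lhs_bits | | Number of bits for the lhs tensor, passed to tensor_make | None |
| rhs_bits | | Number of bits for the rhs tensor, passed to tensor_make | None |
| local_aqt | | Value stored in the local_aqt field of the result | None |

Returns:

| Type | Description |
| --- | --- |
| DotGeneralRaw | A DotGeneralRaw object |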

Source code in src/fjformer/bits/config.py
def dot_general_raw_make(
        lhs_bits=None,
        rhs_bits=None,
        local_aqt=None,
) -> 'DotGeneralRaw':
    """
    Create a DotGeneralRaw object from per-operand bit widths.

    :param lhs_bits: Number of bits for the lhs tensor, passed to tensor_make
    :param rhs_bits: Number of bits for the rhs tensor, passed to tensor_make
    :param local_aqt: Value stored in the local_aqt field of the result
    :return: A DotGeneralRaw object
    """
    lhs_cfg = tensor_make(lhs_bits)
    rhs_cfg = tensor_make(rhs_bits)

    # Binary uses 0.5 right now.
    if (
            lhs_bits is not None
            and rhs_bits is not None
            and 2 <= lhs_bits <= 8
            and 2 <= rhs_bits <= 8
    ):
        dg_accumulator_dtype = jnp.int32
    else:
        dg_accumulator_dtype = None

    return DotGeneralRaw(
        lhs=lhs_cfg,
        rhs=rhs_cfg,
        dg_accumulator_dtype=dg_accumulator_dtype,
        local_aqt=local_aqt,
    )
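
The accumulator rule can be stated as a small pure function. The sketch below mirrors the condition in dot_general_raw_make but returns the string "int32" instead of jnp.int32 so that it runs without JAX; the helper name accumulator_dtype_name is hypothetical.

```python
from typing import Optional


def accumulator_dtype_name(lhs_bits: Optional[int], rhs_bits: Optional[int]) -> Optional[str]:
    """Return "int32" when both operands use 2 to 8 bits, otherwise None."""
    if (
        lhs_bits is not None
        and rhs_bits is not None
        and 2 <= lhs_bits <= 8
        and 2 <= rhs_bits <= 8
    ):
        return "int32"
    return None


both_int8 = accumulator_dtype_name(8, 8)        # "int32"
one_side_off = accumulator_dtype_name(8, None)  # None: rhs is not quantized
binary_lhs = accumulator_dtype_name(1, 8)       # None: 1 bit is outside 2..8
wide_lhs = accumulator_dtype_name(16, 8)        # None: 16 bits is outside 2..8
```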

fully_quantized(*, fwd_bits=8, bwd_bits=8, use_fwd_quant=True, use_stochastic_rounding=True, vjp_lhs_stochastic_rounding=None, vjp_rhs_stochastic_rounding=None, use_dummy_static_bound=False, dlhs_local_aqt=None, drhs_local_aqt=None)

The fully_quantized function builds a DotGeneral configuration that uses fwd_bits for both operands of the forward dot_general and bwd_bits for both operands of the two gradient dot_generals. All parameters are keyword-only. Stochastic rounding is configured either through use_stochastic_rounding (old style) or through the pair vjp_lhs_stochastic_rounding and vjp_rhs_stochastic_rounding (new style), never both. Because use_stochastic_rounding defaults to True, the new style requires passing use_stochastic_rounding=None.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fwd_bits | Optional[int] | Number of bits for both operands of the forward dot_general | 8 |
| bwd_bits | Optional[int] | Number of bits for both operands of the dlhs and drhs dot_generals | 8 |
| use_fwd_quant | bool | Passed to dot_general_make, which assigns it to use_fwd_quant on the rhs of the gradient configs when fwd_bits is not None | True |
| use_stochastic_rounding | Optional[bool] | Old-style switch; True enables stochastic rounding on both sides of the gradient dot_generals | True |
| vjp_lhs_stochastic_rounding | Optional[bool] | New-style switch for the lhs of the gradient dot_generals (typically the gradient); must be passed together with vjp_rhs_stochastic_rounding | None |
| vjp_rhs_stochastic_rounding | Optional[bool] | New-style switch for the rhs of the gradient dot_generals (typically the activations/weights); setting both new-style switches to True is rejected | None |
| use_dummy_static_bound | bool | If True, call set_static_bound(cfg, 1.0); temporary, for performance benchmarking | False |
| dlhs_local_aqt | Optional[LocalQ] | local_aqt setting for the dlhs dot_general | None |
| drhs_local_aqt | Optional[LocalQ] | local_aqt setting for the drhs dot_general | None |

Returns:

| Type | Description |
| --- | --- |
| DotGeneral | A DotGeneral object |

Source code in src/fjformer/bits/config.py
def fully_quantized(
        *,
        fwd_bits: Optional[int] = 8,
        bwd_bits: Optional[int] = 8,
        use_fwd_quant: bool = True,
        use_stochastic_rounding: Optional[bool] = True,
        # Typically we have (but it's a caller's responsibility to check):
        # - vjp_lhs_stochastic_rounding is referring to the gradient and
        # - vjp_rhs_stochastic_rounding is referring to the activations/weights.
        vjp_lhs_stochastic_rounding: Optional[bool] = None,
        vjp_rhs_stochastic_rounding: Optional[bool] = None,
        # The dummy static bound flag is temporary, for performance benchmarking.
        use_dummy_static_bound: bool = False,
        dlhs_local_aqt: Optional[LocalQ] = None,
        drhs_local_aqt: Optional[LocalQ] = None,
) -> DotGeneral:
    """
    Build a DotGeneral configuration that uses fwd_bits for the forward
    dot_general and bwd_bits for both gradient dot_generals. All parameters are
    keyword-only. Use either use_stochastic_rounding (old style) or the pair
    vjp_lhs_stochastic_rounding and vjp_rhs_stochastic_rounding (new style), not both.

    :param fwd_bits: Number of bits for both operands of the forward dot_general
    :param bwd_bits: Number of bits for both operands of the dlhs and drhs dot_generals
    :param use_fwd_quant: Passed to dot_general_make as use_fwd_quant
    :param use_stochastic_rounding: Old-style switch; True enables stochastic rounding on both sides
    :param vjp_lhs_stochastic_rounding: New-style switch for the lhs of the gradient dot_generals
    :param vjp_rhs_stochastic_rounding: New-style switch for the rhs of the gradient dot_generals
    :param use_dummy_static_bound: If True, call set_static_bound(cfg, 1.0)
    :param dlhs_local_aqt: local_aqt setting for the dlhs dot_general
    :param drhs_local_aqt: local_aqt setting for the drhs dot_general
    :return: A DotGeneral object
    """
    cfg = dot_general_make(
        lhs_bits=fwd_bits,
        rhs_bits=fwd_bits,
        bwd_bits=bwd_bits,
        use_fwd_quant=use_fwd_quant,
        dlhs_local_aqt=dlhs_local_aqt,
        drhs_local_aqt=drhs_local_aqt,
    )

    # Stochastic Rounding
    # These 3 variables are used to ensure we don't mix
    # old and new style of SR configuration.
    old_style_sr_config = use_stochastic_rounding is not None
    new_style_sr_config_lhs = vjp_lhs_stochastic_rounding is not None
    new_style_sr_config_rhs = vjp_rhs_stochastic_rounding is not None
    assert new_style_sr_config_lhs == new_style_sr_config_rhs, (
        'if you use new style SR config (vjp_xhs_stochastic_rounding), do pass'
        ' both lhs and rhs explicitly.'
    )
    assert new_style_sr_config_lhs != old_style_sr_config

    true = True  # A crude way to get around g-explicit-bool-comparison warning

    assert not (vjp_lhs_stochastic_rounding and vjp_rhs_stochastic_rounding), (
        'This config is buggy when you set both to True. Contact lew@ or use'
        ' config_v3'
    )

    # By default use jax.uniform for stochastic rounding
    if use_stochastic_rounding == true:
        set_stochastic_rounding(cfg, True, True, 'jax.uniform')

    if vjp_lhs_stochastic_rounding == true:
        set_stochastic_rounding(cfg, True, False, 'jax.uniform')

    if vjp_rhs_stochastic_rounding == true:
        set_stochastic_rounding(cfg, False, True, 'jax.uniform')

    if use_dummy_static_bound:
        set_static_bound(cfg, 1.0)

    return cfg
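
The three assertions that guard the stochastic rounding arguments are worth tracing once, since the defaults interact: with use_stochastic_rounding left at True, passing the vjp flags fails. The sketch below reproduces only that validation and returns the resulting (lhs, rhs) stochastic rounding switches. The name resolve_sr is hypothetical, and the sketch raises ValueError where the original uses assert.

```python
from typing import Optional, Tuple


def resolve_sr(
    use_stochastic_rounding: Optional[bool] = True,
    vjp_lhs: Optional[bool] = None,
    vjp_rhs: Optional[bool] = None,
) -> Tuple[bool, bool]:
    """Validate the old/new style flags and return (lhs_sr, rhs_sr)."""
    old_style = use_stochastic_rounding is not None
    new_lhs = vjp_lhs is not None
    new_rhs = vjp_rhs is not None
    if new_lhs != new_rhs:
        raise ValueError("pass both vjp flags or neither")
    if new_lhs == old_style:
        raise ValueError("use exactly one of the old and new styles")
    if vjp_lhs and vjp_rhs:
        raise ValueError("both new-style flags set to True is rejected")
    if use_stochastic_rounding:
        # Old style with True: stochastic rounding on both sides.
        return (True, True)
    # Old style with False, or new style: only explicitly enabled sides.
    return (bool(vjp_lhs), bool(vjp_rhs))


default_style = resolve_sr()
new_style = resolve_sr(use_stochastic_rounding=None, vjp_lhs=True, vjp_rhs=False)
```

Note that enabling both sides is only reachable through the old-style flag; the new-style combination (True, True) is rejected, and the source points to config_v3 as the alternative.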

set_accumulator_dtype(cfg, fwd_dtype, dlhs_dtype, drhs_dtype)

The set_accumulator_dtype function sets dg_accumulator_dtype on each of the three dot_general configs (fwd, dlhs, and drhs). It modifies cfg in place and returns nothing. All three dtype arguments are required; passing None stores None for that dot_general.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | DotGeneral | Configuration to modify in place | required |
| fwd_dtype | Optional[DType] | Accumulator dtype for the forward dot_general | required |
| dlhs_dtype | Optional[DType] | Accumulator dtype for the dlhs dot_general | required |
| drhs_dtype | Optional[DType] | Accumulator dtype for the drhs dot_general | required |

Source code in src/fjformer/bits/config.py
def set_accumulator_dtype(
        cfg: DotGeneral,
        fwd_dtype: Optional[DType],
        dlhs_dtype: Optional[DType],
        drhs_dtype: Optional[DType],
):
    """
    Set dg_accumulator_dtype on the fwd, dlhs, and drhs configs of cfg.
    Modifies cfg in place and returns nothing.

    :param cfg: Configuration to modify in place
    :param fwd_dtype: Accumulator dtype for the forward dot_general
    :param dlhs_dtype: Accumulator dtype for the dlhs dot_general
    :param drhs_dtype: Accumulator dtype for the drhs dot_general
    """
    cfg.fwd.dg_accumulator_dtype = fwd_dtype
    cfg.dlhs.dg_accumulator_dtype = dlhs_dtype
    cfg.drhs.dg_accumulator_dtype = drhs_dtype

set_fwd_numerics(cfg, fwd_numerics)

The set_fwd_numerics function assigns fwd_numerics to both the lhs and the rhs tensor config of the forward dot_general. It modifies cfg in place and returns nothing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | | Configuration to modify in place; it must expose fwd.lhs and fwd.rhs, as a DotGeneral does | required |
| fwd_numerics | QNumerics | Numerics assigned to cfg.fwd.lhs.numerics and cfg.fwd.rhs.numerics | required |

Source code in src/fjformer/bits/config.py
def set_fwd_numerics(cfg, fwd_numerics: numerics.QNumerics):
    """
    Assign fwd_numerics to both operands of the forward dot_general.
    Modifies cfg in place and returns nothing.

    :param cfg: Configuration to modify in place
    :param fwd_numerics: Numerics assigned to cfg.fwd.lhs and cfg.fwd.rhs
    """
    cfg.fwd.lhs.numerics = fwd_numerics
    cfg.fwd.rhs.numerics = fwd_numerics

set_static_bound(cfg, bound=1.0)

The set_static_bound function replaces the calibration of all six tensor configs (lhs and rhs of fwd, dlhs, and drhs) with calibration.ConstantCalibration(bound). It modifies cfg in place. In fully_quantized, the flag that triggers this call is described as temporary and intended for performance benchmarking.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | DotGeneral | Configuration to modify in place | required |
| bound | float | Constant passed to calibration.ConstantCalibration | 1.0 |

Source code in src/fjformer/bits/config.py
def set_static_bound(cfg: DotGeneral, bound: float = 1.0):

    """
    Replace the calibration of all six tensor configs (lhs and rhs of fwd,
    dlhs, and drhs) with calibration.ConstantCalibration(bound).

    :param cfg: Configuration to modify in place
    :param bound: Constant passed to calibration.ConstantCalibration
    """
    cfg.fwd.lhs.calibration = calibration.ConstantCalibration(bound)
    cfg.fwd.rhs.calibration = calibration.ConstantCalibration(bound)
    cfg.drhs.lhs.calibration = calibration.ConstantCalibration(bound)
    cfg.drhs.rhs.calibration = calibration.ConstantCalibration(bound)
    cfg.dlhs.lhs.calibration = calibration.ConstantCalibration(bound)
    cfg.dlhs.rhs.calibration = calibration.ConstantCalibration(bound)

set_stochastic_rounding(cfg, vjp_lhs_stochastic_rounding, vjp_rhs_stochastic_rounding, implementation)

Configure stochastic rounding implementation.

Source code in src/fjformer/bits/config.py
def set_stochastic_rounding(
        cfg: DotGeneral,
        vjp_lhs_stochastic_rounding: bool,
        vjp_rhs_stochastic_rounding: bool,
        implementation: str,
):
    """Configure stochastic rounding implementation."""
    noise_implementations = {
        'jax.uniform': lambda shape, key: jax.random.uniform(key, shape) - 0.5,
        'custom-1': stochastic_rounding.random_centered_uniform,
    }
    msg = f'{implementation} not supported.'
    assert implementation in noise_implementations.keys(), msg
    noise_fn = noise_implementations[implementation]

    if vjp_lhs_stochastic_rounding:
        cfg.dlhs.lhs.numerics = cfg.dlhs.lhs.numerics.replace(noise_fn=noise_fn)
        cfg.drhs.lhs.numerics = cfg.drhs.lhs.numerics.replace(noise_fn=noise_fn)
    else:
        cfg.dlhs.lhs.numerics = cfg.dlhs.lhs.numerics.replace(noise_fn=None)
        cfg.drhs.lhs.numerics = cfg.drhs.lhs.numerics.replace(noise_fn=None)

    if vjp_rhs_stochastic_rounding:
        cfg.dlhs.rhs.numerics = cfg.dlhs.rhs.numerics.replace(noise_fn=noise_fn)
        cfg.drhs.rhs.numerics = cfg.drhs.rhs.numerics.replace(noise_fn=noise_fn)
    else:
        cfg.dlhs.rhs.numerics = cfg.dlhs.rhs.numerics.replace(noise_fn=None)
        cfg.drhs.rhs.numerics = cfg.drhs.rhs.numerics.replace(noise_fn=None)
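
The 'jax.uniform' entry produces noise that is uniform on [-0.5, 0.5). The pure-Python sketch below shows why noise of that shape makes rounding unbiased on average. It assumes the noise is added to the value before round-to-nearest; the source above only registers the noise function and does not show how the numerics apply it, so treat the combination step as an assumption. The names centered_uniform and stochastic_round are illustrative.

```python
import math
import random


def centered_uniform(rng: random.Random) -> float:
    """Uniform noise on [-0.5, 0.5), like jax.random.uniform(key, shape) - 0.5."""
    return rng.random() - 0.5


def stochastic_round(x: float, rng: random.Random) -> int:
    """Add centered noise, then round to nearest (assumed combination step)."""
    return math.floor(x + centered_uniform(rng) + 0.5)


rng = random.Random(0)
samples = [stochastic_round(0.3, rng) for _ in range(20000)]

# Every sample is 0 or 1, and the average stays close to the input 0.3,
# whereas plain round-to-nearest would always return 0.
mean = sum(samples) / len(samples)
```

Unbiasedness of this kind is the usual motivation for applying stochastic rounding to gradients, which matches the note in the source that the lhs flag typically refers to the gradient.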

tensor_make(bits)

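The tensor_make function creates a Tensor config. With bits=None it uses no_numerics.NoNumerics; otherwise it uses int_numerics.IntNumerics with the given bit width. Zero is preserved for every bit width except 1, and the dtype is jnp.int8 for 2 to 8 bits and None otherwise. The remaining fields get fixed defaults, including calibration.AbsMaxCalibration() for calibration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bits | Optional[int] | Number of bits for quantization; None selects NoNumerics | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | A Tensor object |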

Source code in src/fjformer/bits/config.py
def tensor_make(bits: Optional[int]) -> 'Tensor':

    """
    Create a Tensor config with NoNumerics (bits is None) or IntNumerics.

    :param bits: Number of bits for quantization; None selects NoNumerics
    :return: A Tensor object
    """
    if bits is None:
        effective_numerics = no_numerics.NoNumerics()
    else:
        pz = bits != 1
        dtype = jnp.int8 if 2 <= bits <= 8 and pz else None
        effective_numerics = int_numerics.IntNumerics(
            bits=bits,
            preserve_zero=pz,
            preserve_max_val=False,
            clip=True,
            round=True,
            noise_fn=None,
            clip_gradient=False,  # This can be disabled when using abs-max scaling.
            dtype=dtype,
        )

    return Tensor(
        numerics=effective_numerics,
        calib_shared_axes=None,
        scale_stop_grad=True,
        calibration=calibration.AbsMaxCalibration(),
        po2_scale=False,
        use_fake_quant=False,
        use_fwd_quant=None,
        preprocess=None,
    )
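
The branch on bits decides three things: which numerics class is used, whether zero is preserved, and whether values are stored as int8. The sketch below tabulates those choices with plain strings in place of the real classes and dtypes so that it runs without JAX; tensor_numerics_choice is a hypothetical helper.

```python
from typing import Optional


def tensor_numerics_choice(bits: Optional[int]) -> dict:
    """Summarise the numerics selection made by tensor_make for a bit width."""
    if bits is None:
        return {"numerics": "NoNumerics", "preserve_zero": None, "dtype": None}
    # Only the 1-bit case gives up an exact zero.
    preserve_zero = bits != 1
    # int8 storage is chosen only for 2 to 8 bits.
    dtype = "int8" if 2 <= bits <= 8 and preserve_zero else None
    return {"numerics": "IntNumerics", "preserve_zero": preserve_zero, "dtype": dtype}


unquantized = tensor_numerics_choice(None)
binary = tensor_numerics_choice(1)
int8_case = tensor_numerics_choice(8)
wide_case = tensor_numerics_choice(16)
```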