Skip to content

bits.q_dot_general

Quantized dot_general.

TensorRes

All the things we pass from the forward pass to the backward pass.

Source code in src/fjformer/bits/q_dot_general.py
125
126
127
128
129
@flax.struct.dataclass
class TensorRes:
    """Residuals handed from the forward pass over to the backward pass.

    Being a ``flax.struct.dataclass``, instances are immutable pytrees whose
    fields are, in declaration order, ``mt`` and ``quant_grad``.
    """

    # Tensor state recorded during the forward pass.
    mt: MultiTensor
    # Optional callable taking one array and returning a tuple of arrays
    # (per the annotation); may be None.
    quant_grad: Optional[Callable[[jnp.ndarray], tuple[jnp.ndarray]]]

make_dot_general(cfg)

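The make_dot_general function builds a function with the same calling convention as jax.lax.dot_general. When cfg is None, the returned function simply forwards to jax.lax.dot_general. Otherwise it returns a quantized dot_general whose forward pass and both gradient passes (cfg.fwd, cfg.dlhs, cfg.drhs) are configured by cfg. The quantized variant takes two floating-point arrays, lhs and rhs (bfloat16, float32 or float16), and returns the dot_general result. It also runs the optional per-operand preprocess hooks (cfg.fwd.lhs.preprocess / cfg.fwd.rhs.preprocess):
- before the computation, each hook is called with None to obtain a stored QTensor to pass in;
- after the computation, each hook is called with the QTensor produced for that operand.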

Parameters:

Name Type Description Default
cfg Optional[DotGeneral]

Optional[config.DotGeneral]: Configuration of the quantized dot_general operation; pass None to get a plain jax.lax.dot_general wrapper.

required

Returns:

Type Description

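A dot_general-compatible function, (lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, *, context=...), that returns the dot_general result.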

Source code in src/fjformer/bits/q_dot_general.py
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def make_dot_general(cfg: Optional[config.DotGeneral]):
    """
    Build a function with the calling convention of ``jax.lax.dot_general``.

    When ``cfg`` is None the returned function forwards directly to
    ``jax.lax.dot_general`` (its ``context`` argument is accepted and ignored).

    Otherwise the forward dot_general and the two gradient dot_generals are
    built from ``cfg.fwd``, ``cfg.dlhs`` and ``cfg.drhs`` with
    ``_make_dot_general_raw`` and wired together by
    ``_dot_general_raw_attach_gradient``. The returned quantized function:

    * rejects an explicit ``precision`` and ignores ``preferred_element_type``;
    * only accepts bfloat16 / float32 / float16 operands;
    * for each operand whose ``cfg.fwd.<operand>.preprocess`` hook is set,
      calls the hook with ``None`` before the computation to obtain a QTensor
      to feed in. After the computation it calls the hook again with the
      QTensor produced for that operand (qvalue cast to
      ``cfg.fwd.<operand>.numerics.get_dtype()``); that second call must
      return None.

    :param cfg: Optional[config.DotGeneral]: Configuration of the quantized
        dot_general, or None for a plain ``jax.lax.dot_general`` wrapper.
    :return: A function
        ``(lhs, rhs, dimension_numbers, precision=None,
        preferred_element_type=None, *, context=...)`` that returns the
        dot_general result.
    """
    if cfg is None:
        def ret_lax_dg(
                lhs,
                rhs,
                dimension_numbers,
                precision=None,
                preferred_element_type=None,
                *,
                context=Context(key=None, train_step=None),
        ):
            """
            Forward directly to ``jax.lax.dot_general`` (no quantization).

            :param lhs: Left-hand operand.
            :param rhs: Right-hand operand.
            :param dimension_numbers: Contraction/batch dimension spec, as
                accepted by ``jax.lax.dot_general``.
            :param precision: Forwarded to ``jax.lax.dot_general``.
            :param preferred_element_type: Forwarded to
                ``jax.lax.dot_general``.
            :param context: Accepted only for signature compatibility with the
                quantized variant; ignored.
            :return: The result of ``jax.lax.dot_general``.
            """
            del context
            return jax.lax.dot_general(
                lhs, rhs, dimension_numbers, precision, preferred_element_type
            )

        return ret_lax_dg

    dg = _dot_general_raw_attach_gradient(
        fwd_dot_general_raw=_make_dot_general_raw(cfg.fwd),
        dlhs_dot_general_raw=_make_dot_general_raw(cfg.dlhs),
        drhs_dot_general_raw=_make_dot_general_raw(cfg.drhs),
    )

    # Operand dtypes the quantized path accepts; hoisted out of the call path.
    supported_dtypes = (jnp.bfloat16, jnp.float32, jnp.float16)

    def _load_qt(operand_cfg):
        """Fetch the QTensor the operand's preprocess hook supplies, if any."""
        if operand_cfg.preprocess is None:
            return None
        # The returned qvalue may be in a quantized dtype. This breaks the
        # invariant that QTensor has a float qvalue, but it will just be cast
        # again to the same type.
        return operand_cfg.preprocess(None)

    def _store_qt(operand_cfg, qt):
        """Pass the QTensor produced by dg back to the preprocess hook."""
        if operand_cfg.preprocess is None:
            return
        operand_dtype = operand_cfg.numerics.get_dtype()
        stored_qt = QTensor(qt.qvalue.astype(operand_dtype), qt.qvalue_scale_t)
        none = operand_cfg.preprocess(stored_qt)
        assert none is None, 'preprocess(qt) is expected to return None.'

    def ret_dg(
            lhs,
            rhs,
            dimension_numbers,
            precision=None,
            preferred_element_type=None,
            *,
            context=Context(key=None, train_step=None),
    ):
        """
        Quantized dot_general configured by the enclosing ``cfg``.

        :param lhs: Left-hand operand; dtype must be bfloat16, float32 or
            float16.
        :param rhs: Right-hand operand; dtype must be bfloat16, float32 or
            float16.
        :param dimension_numbers: Contraction/batch dimension spec, as
            accepted by ``jax.lax.dot_general``.
        :param precision: Must be None; explicit precision is not supported
            together with quantization.
        :param preferred_element_type: Accepted for signature compatibility
            and silently ignored.
        :param context: Forwarded to the underlying quantized dot_general.
        :return: The dot_general output. The per-operand QTensors are handed
            to the preprocess hooks (when set), not returned.
        :raises AssertionError: If ``precision`` is not None or an operand
            has an unsupported dtype (checks are skipped under ``python -O``).
        """
        del preferred_element_type
        assert (
                precision is None
        ), f'Precision {precision} requested together with quantization.'

        msg = 'AQT is not yet optimized to accept quantized types directly. '
        msg += f'lhs.dtype: {lhs.dtype}, rhs.dtype: {rhs.dtype}'
        assert lhs.dtype in supported_dtypes, msg
        assert rhs.dtype in supported_dtypes, msg
        # TODO(lew): Refactor Have a flax class with get and set.
        lhs_qt = _load_qt(cfg.fwd.lhs)
        rhs_qt = _load_qt(cfg.fwd.rhs)

        out, (out_lhs_qt, out_rhs_qt) = dg(
            lhs=lhs,
            rhs=rhs,
            lhs_qt=lhs_qt,
            rhs_qt=rhs_qt,
            dimension_numbers=dimension_numbers,
            context=context,
        )

        _store_qt(cfg.fwd.lhs, out_lhs_qt)
        _store_qt(cfg.fwd.rhs, out_rhs_qt)

        return out

    return ret_dg