Skip to content

bits.qk

quantize_kv(kv)

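Quantize the key/values stored in the KV cache to int8, returning the quantized values together with the scale (maximum absolute value along the last axis) needed to dequantize them.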

Source code in src/fjformer/bits/qk.py
 7
 8
 9
10
11
def quantize_kv(kv: chex.Array):
    """Quantize key/values stored in kvcache to int8.

    The scale is the maximum absolute value along the last axis (kept as a
    size-1 axis so it broadcasts against ``kv``). Each element is multiplied
    by ``MAX_INT8 / scale``, rounded to the nearest integer, and cast to int8.

    Args:
        kv: Array of keys or values to quantize.

    Returns:
        A ``(value, scale)`` tuple: ``value`` is the int8 array with the same
        shape as ``kv``; ``scale`` is the max-abs array with the last axis
        reduced to size 1.
    """
    scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True)
    # An all-zero slice gives scale == 0; dividing by it would produce
    # inf, and 0 * inf is NaN, which then gets cast to int8. Divide by 1
    # in that case instead: every element of such a slice is 0, so it
    # quantizes to 0.
    safe_scale = jnp.where(scale == 0, 1, scale)
    value = jnp.int8(jnp.rint(kv * (MAX_INT8 / safe_scale)))
    # Return the unmodified scale so the output for non-zero slices, and the
    # dequantized result for zero slices (0 * 0 / MAX_INT8 == 0), are unchanged.
    return value, scale

unquantize_kv(value, scale, dtype)

Dequantize key/values stored in the KV cache: cast the quantized values to the requested dtype and rescale them with the stored scale.

Source code in src/fjformer/bits/qk.py
14
15
16
def unquantize_kv(value: chex.Array, scale: chex.Array, dtype: jnp.dtype):
    """Recover key/values from their quantized kvcache representation.

    Args:
        value: Quantized array; it is cast to ``dtype`` before rescaling.
        scale: Scale array multiplied into the cast values.
        dtype: Dtype the quantized values are cast to.

    Returns:
        ``value`` cast to ``dtype``, multiplied by ``scale`` and divided by
        ``MAX_INT8``.
    """
    # Same evaluation order as a single left-to-right expression:
    # multiply by the scale first, then divide by the int8 range constant.
    rescaled = value.astype(dtype) * scale
    return rescaled / MAX_INT8