Skip to content

pallas_operations.efficient_attention.efficient_attention

efficient_attention(query, key, value, bias=None, deterministic=True, dropout_rng=None, attention_drop_rate=0.0, causal=True, query_chunk_size=1024, key_chunk_size=1024, dtype=jnp.float32, policy=jax.checkpoint_policies.nothing_saveable(), precision=None, float32_logits=True, prevent_cse=True)

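Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `query` | `Array` | Array of shape [batch, Q sequence length, num attention heads, head dims] | *required* |
| `key` | `Array` | Array of shape [batch, KV sequence length, num KV attention heads, head dims] | *required* |
| `value` | `Array` | Array of shape [batch, KV sequence length, num KV attention heads, head dims] | *required* |
| `bias` | `Array` | Additive attention bias; each dimension must be 1 or match [batch, num attention heads, Q sequence length, KV sequence length] | `None` |
| `deterministic` | `bool` | If `True`, attention dropout is disabled | `True` |
| `dropout_rng` | `PRNGKey` | RNG key used to sample the attention dropout mask | `None` |
| `attention_drop_rate` | `float` | Attention dropout rate; only used when `deterministic` is `False` | `0.0` |
| `causal` | `bool` | Whether to apply causal (decoder-style) masking | `True` |
| `query_chunk_size` | `int` | Chunk size used for the query sequence | `1024` |
| `key_chunk_size` | `int` | Chunk size used for the key/value sequence | `1024` |
| `dtype` | `ArrayDType` | Data type of the output | `float32` |
| `policy` | | Gradient checkpoint policy | `nothing_saveable()` |
| `precision` | `PrecisionLike` | Precision passed to the `einsum` calls | `None` |
| `float32_logits` | `bool` | Cast query and key to float32 before computing the attention logits | `True` |
| `prevent_cse` | `bool` | Forwarded to `jax.checkpoint` | `True` |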

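Returns:

| Type | Description |
|------|-------------|
| `Array` | Attention output of shape [batch, Q sequence length, num attention heads, head dims], cast to `dtype` |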
Source code in `src/fjformer/pallas_operations/efficient_attention/efficient_attention.py` (lines 16–138):
def efficient_attention(
        query: chex.Array,
        key: chex.Array,
        value: chex.Array,
        bias: chex.Array = None,
        deterministic: bool = True,
        dropout_rng: chex.PRNGKey = None,
        attention_drop_rate: float = 0.0,
        causal: bool = True,
        query_chunk_size: int = 1024,
        key_chunk_size: int = 1024,
        dtype: chex.ArrayDType = jnp.float32,
        policy=jax.checkpoint_policies.nothing_saveable(),
        precision=None,
        float32_logits: bool = True,
        prevent_cse: bool = True,
):
    """Blockwise (memory-efficient) attention using a streaming softmax.

    The query sequence is split into chunks of ``query_chunk_size``, and the key/value
    sequence into chunks of ``key_chunk_size``. For every query chunk, the key/value
    chunks are scanned with ``lax.scan``. A running numerator, denominator and row-wise
    max are carried through the scan, so the full score matrix is never materialized at
    once. Each key/value step is wrapped in ``jax.checkpoint``.

    :param query: Array Shape [batch, Q Sequence length, num attention heads, head dims]
    :param key: Array Shape [batch, KV Sequence length, num KV attention heads, head dims]
    :param value: Array Shape [batch, KV Sequence length, num KV attention heads, value head dims].
        The value head dims may differ from the query/key head dims.
    :param bias: Optional additive bias. Each dimension must be 1 or equal to the matching
        entry of [batch, num attention heads, Q Sequence length, KV Sequence length].
    :param deterministic: If True, no attention dropout is applied.
    :param dropout_rng: PRNG key used to sample the dropout mask (needed when dropout is active).
    :param attention_drop_rate: Probability passed to ``jax.random.bernoulli`` when sampling
        the attention dropout mask. Only used when ``deterministic`` is False.
    :param causal: Apply causal masking, and skip key/value chunks that lie entirely after
        the query chunk.
    :param query_chunk_size: Chunk size used for query; clamped to the Q Sequence length.
    :param key_chunk_size: Chunk size used for key/value; clamped to the KV Sequence length.
    :param dtype: Output dtype. Also used for the scaling factor and passed to the chunk-bias helper.
    :param policy: Gradient checkpoint policy applied to each key/value block.
    :param precision: Precision passed to the ``jnp.einsum`` calls.
    :param float32_logits: If True, cast query and key to float32 before computing logits.
    :param prevent_cse: Forwarded to ``jax.checkpoint``.
    :return: Array of shape [batch, Q Sequence length, num attention heads, value head dims],
        cast to ``dtype``.
    :raises ValueError: Raised when any of the following holds:
        query/key/value shapes are inconsistent;
        the number of query heads is not a multiple of the number of KV heads;
        a sequence length is not divisible by its (clamped) chunk size;
        ``bias`` is not broadcastable.
    """
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, qk_dim = query.shape
    k_batch, kv_len, kv_heads, k_dim = key.shape
    if tuple(value.shape[:3]) != tuple(key.shape[:3]):
        raise ValueError(
            "key and value must agree on [batch, KV sequence length, num KV heads]; "
            f"got key {key.shape} and value {value.shape}"
        )
    # Tracked separately from qk_dim: the value head dim may differ from the query/key one.
    v_dim = value.shape[-1]
    if (k_batch, k_dim) != (batch, qk_dim):
        raise ValueError(
            "query and key must agree on batch size and head dims; "
            f"got query {query.shape} and key {key.shape}"
        )
    if num_heads % kv_heads != 0:
        raise ValueError(
            f"num attention heads ({num_heads}) must be a multiple of "
            f"num KV attention heads ({kv_heads})"
        )
    if kv_heads != num_heads:
        # Grouped-/multi-query attention. Each KV head serves num_heads // kv_heads
        # consecutive query heads. Repeat the KV heads so they line up with the query
        # heads expected by the 'bqhd,bkhd->bqhk' einsum below.
        repeats = num_heads // kv_heads
        key = jnp.repeat(key, repeats, axis=2)
        value = jnp.repeat(value, repeats, axis=2)

    # A chunk longer than its sequence would yield zero chunks (and a failing reshape).
    # Clamp it so that short sequences are processed as a single chunk.
    query_chunk_size = min(query_chunk_size, q_len)
    key_chunk_size = min(key_chunk_size, kv_len)
    if query_chunk_size <= 0 or q_len % query_chunk_size != 0:
        raise ValueError(
            f"Q sequence length ({q_len}) must be positive and divisible by "
            f"query_chunk_size ({query_chunk_size})"
        )
    if key_chunk_size <= 0 or kv_len % key_chunk_size != 0:
        raise ValueError(
            f"KV sequence length ({kv_len}) must be positive and divisible by "
            f"key_chunk_size ({key_chunk_size})"
        )

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    # [batch, seq, heads, dim] -> [num_chunks, batch, chunk, heads, dim], so that
    # lax.scan iterates over chunks along the leading axis.
    query = jnp.moveaxis(
        query.reshape((batch, num_q, query_chunk_size, num_heads, qk_dim)), 1, 0
    )
    key = jnp.moveaxis(
        key.reshape((batch, num_kv, key_chunk_size, num_heads, qk_dim)), 1, 0
    )
    value = jnp.moveaxis(
        value.reshape((batch, num_kv, key_chunk_size, num_heads, v_dim)), 1, 0
    )

    if bias is not None:
        broadcast_shape = (batch, num_heads, q_len, kv_len)
        for bias_dim, broadcast_dim in zip(bias.shape, broadcast_shape):
            if bias_dim != 1 and bias_dim != broadcast_dim:
                raise ValueError(
                    f"bias of shape {bias.shape} is not broadcastable to {broadcast_shape}"
                )
    if not deterministic and attention_drop_rate > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(
            attn_dropout_rng, attention_drop_rate, (batch, num_heads, q_len, kv_len)
        )
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attention_drop_rate, causal, dtype)

    # The numerator accumulates exp_weights @ value. Its dtype must match what the einsum
    # produces; otherwise the scan carry changes dtype between iterations.
    numerator_dtype = jnp.result_type(query.dtype, value.dtype)

    def scan_attention(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            # numerator / denominator does not depend on the shift by max_score (it only
            # provides numerical stability), so no gradient needs to flow through it.
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            # Rescale the previous accumulators from the old running max to the new one.
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                # Skip the block when this chunk's last query position precedes its first
                # key position, i.e. every (query, key) pair in the block is in the future.
                # With equal chunk sizes this reduces to query_chunk_idx < key_chunk_idx.
                # Unlike that check, it stays correct when the two chunk sizes differ.
                # NOTE(review): assumes the causal mask in _chunk_attention_bias compares
                # absolute positions starting at 0 for both sequences (as the original
                # chunk-index comparison did) — confirm.
                skip_block = (
                    (query_chunk_idx + 1) * query_chunk_size
                    <= key_chunk_idx * key_chunk_size
                )
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, v_dim), dtype=numerator_dtype),
            # One softmax normalizer per (query, head) row; it broadcasts against the numerator.
            jnp.zeros((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
            jnp.full((batch, query_chunk_size, num_heads, 1), -jnp.inf, dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    # [num_q, batch, chunk, heads, dim] -> [batch, q_len, heads, dim]
    res = rearrange(res, 'n b c h d -> b (n c) h d')
    return res