Skip to content

modules.attention_module

AttentionModule

Source code in src/python/easydel/modules/attention_module.py
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
class AttentionModule:
    def __init__(
            self,
            mesh: Mesh,
            attn_mechanism: Literal[
                "vanilla",
                "flash",
                "splash",
                "ring",
                "cudnn",
                "local_ring",
                "sharded_vanilla",
                "wise_ring",
                "blockwise",
                "pallas_flash"
            ],
            num_attention_heads: int,
            head_dims: int,
            block_k: int = DEFAULT_K_BLOCK,
            block_q: int = DEFAULT_Q_BLOCK,
            block_b: int = DEFAULT_Q_BLOCK,
            block_k_major: int = DEFAULT_K_BLOCK,
            block_q_major_dkv: int = DEFAULT_Q_BLOCK,
            block_k_major_dkv: int = DEFAULT_K_BLOCK,
            block_k_dkv: int = DEFAULT_K_BLOCK,
            block_q_dkv: int = DEFAULT_Q_BLOCK,
            block_k_major_dq: int = DEFAULT_K_BLOCK,
            block_k_dq: int = DEFAULT_K_BLOCK,
            block_q_dq: int = DEFAULT_Q_BLOCK,
            sm_scale: Optional[float] = None,
            query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            generation_query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, "tp", None),
            key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, "sp", None),
            generation_bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
            attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            generation_attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, "tp", None),
            scan_ring_attention: bool = True,
            scan_attention_layers: bool = True,
            attention_dropout: float = 0.0,
            dtype: jnp.dtype = jnp.float32,
            precision: lax.Precision = lax.Precision("fastest"),
            force_float32_tpu: bool = True,
            shard_attention_computation: bool = True,
            use_sharding_constraint: Optional[bool] = False,
            axis_name: str = "sp",
            backward_pass_impl: Literal["triton", "xla"] = "triton"
    ):
        """Store the attention configuration and validate it against the platform.

        Every argument is kept on the instance under the same name. The
        detected default JAX backend is kept as ``self.platform``.

        Args:
            mesh: Device mesh; entered as a context manager by ``__call__``
                and handed to ``shard_map`` by the ring variants.
            attn_mechanism: Name of the attention implementation that
                ``__call__`` dispatches to.
            num_attention_heads: Head count enforced by ``_check_states``.
            head_dims: Per-head feature size enforced by ``_check_states``;
                also used for the default ``sm_scale``.
            block_k, block_q, block_b, block_k_major, block_q_major_dkv,
            block_k_major_dkv, block_k_dkv, block_q_dkv, block_k_major_dq,
            block_k_dq, block_q_dq: Block-size limits that the
                ``get_block_size_*`` helpers clamp against sequence lengths.
            sm_scale: Softmax scale; defaults to ``1 / sqrt(head_dims)``.
            query_partition_spec, key_partition_spec, value_partition_spec,
            bias_partition_spec, attention_partition_spec: Partition specs
                for the regular path.
            generation_query_partition_spec, generation_bias_partition_spec,
            generation_attention_partition_spec: Partition specs that
                ``get_partition_specs`` selects when the query length is 1.
            scan_ring_attention: Allows ``ring_attention`` and
                ``wise_ring_attention`` to take their blockwise path when the
                query is longer than ``max(block_q, block_k)``.
            scan_attention_layers: Forwarded (negated) as ``prevent_cse``.
            attention_dropout: Dropout rate forwarded to the kernels that
                accept one.
            dtype: Computation dtype forwarded to the kernels.
            precision: ``lax.Precision`` forwarded to the kernels.
            force_float32_tpu: Stored on the instance.
            shard_attention_computation: Forwarded to ``vanilla_attention``.
            use_sharding_constraint: Stored on the instance.
            axis_name: Mesh axis name handed to the ring attention kernels.
            backward_pass_impl: Stored on the instance.

        Raises:
            OSError: If ``splash`` or ``flash`` is requested off TPU, or
                ``cudnn`` is requested off GPU.
        """
        # Public accessor for the default backend's platform name; avoids the
        # private `jax.lib.xla_bridge` module.
        platform = jax.default_backend()
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(head_dims)
        self.platform = platform
        self.attn_mechanism = attn_mechanism
        self.block_k = block_k
        self.block_q = block_q
        self.block_b = block_b
        self.block_k_major = block_k_major
        self.block_q_major_dkv = block_q_major_dkv
        self.block_k_major_dkv = block_k_major_dkv
        self.block_k_dkv = block_k_dkv
        self.block_q_dkv = block_q_dkv
        self.block_k_major_dq = block_k_major_dq
        self.block_k_dq = block_k_dq
        self.block_q_dq = block_q_dq
        self.num_attention_heads = num_attention_heads
        self.head_dims = head_dims
        self.sm_scale = sm_scale
        self.mesh = mesh
        self.query_partition_spec = query_partition_spec
        self.key_partition_spec = key_partition_spec
        self.value_partition_spec = value_partition_spec
        self.bias_partition_spec = bias_partition_spec
        self.attention_partition_spec = attention_partition_spec
        self.attention_dropout = attention_dropout
        self.dtype = dtype
        self.precision = precision
        self.force_float32_tpu = force_float32_tpu
        self.shard_attention_computation = shard_attention_computation
        self.use_sharding_constraint = use_sharding_constraint
        self.scan_ring_attention = scan_ring_attention
        self.scan_attention_layers = scan_attention_layers
        self.generation_query_partition_spec = generation_query_partition_spec
        self.generation_bias_partition_spec = generation_bias_partition_spec
        self.generation_attention_partition_spec = generation_attention_partition_spec
        self.axis_name = axis_name
        self.backward_pass_impl = backward_pass_impl

        # Reject mechanism/platform combinations that cannot run.
        if attn_mechanism == "splash" and self.platform != "tpu":
            raise OSError("splash attention is only supported on TPU.")
        if attn_mechanism == "flash" and self.platform != "tpu":
            error_msg = "flash attention is only supported on TPU"
            if self.platform == "gpu":
                error_msg += ", for GPUs flash attention you can use `cudnn`."
            raise OSError(error_msg)
        if attn_mechanism == "cudnn" and self.platform != "gpu":
            # The message names the mechanism that was actually requested.
            raise OSError("cudnn attention is only supported on GPU.")

    def get_block_size_splash_attn(self, q_seq, k_seq):
        """Build splash-attention block sizes clamped to the sequence lengths.

        Query-side blocks are limited by ``q_seq`` and key/value-side blocks
        by ``k_seq``, so no block exceeds the dimension it tiles.

        Args:
            q_seq: Query sequence length.
            k_seq: Key/value sequence length.

        Returns:
            A ``BlockSizesSplashAttn`` holding the clamped block sizes.
        """
        return BlockSizesSplashAttn(
            block_q=min(self.block_q, q_seq),
            block_kv_compute=min(self.block_k, k_seq),
            block_kv=min(self.block_k, k_seq),
            block_q_dkv=min(self.block_q_dkv, q_seq),
            block_kv_dkv=min(self.block_k_dkv, k_seq),
            # Key/value-side blocks must be bounded by the key/value length,
            # not the query length.
            block_kv_dkv_compute=min(self.block_k_dkv, k_seq),
            block_q_dq=min(self.block_q_dq, q_seq),
            block_kv_dq=min(self.block_k_dq, k_seq),
        )

    def get_block_size_flash_attn(self, q_seq, k_seq):
        """Build flash-attention block sizes clamped to the sequence lengths.

        Each ``block_q*`` value is limited by ``q_seq`` and each ``block_k*``
        value by ``k_seq``; every field is derived from the instance
        attribute of the same name.

        Args:
            q_seq: Query sequence length.
            k_seq: Key/value sequence length.

        Returns:
            A ``BlockSizesFlashAttn`` holding the clamped block sizes.
        """
        return BlockSizesFlashAttn(
            block_q=min(self.block_q, q_seq),
            block_k=min(self.block_k, k_seq),
            block_q_dkv=min(self.block_q_dkv, q_seq),
            # Derived from `block_k_dq` (not `block_k_dkv`) so the dq setting
            # is honoured.
            block_k_dq=min(self.block_k_dq, k_seq),
            block_k_dkv=min(self.block_k_dkv, k_seq),
            block_q_dq=min(self.block_q_dq, q_seq),
            # The batch block is capped at 1 regardless of `self.block_b`.
            block_b=min(self.block_b, 1),
            # Key-side "major" blocks tile the key/value axis, so they are
            # bounded by `k_seq`.
            block_k_major=min(self.block_k_major, k_seq),
            block_k_major_dq=min(self.block_k_major_dq, k_seq),
            block_k_major_dkv=min(self.block_k_major_dkv, k_seq),
            block_q_major_dkv=min(self.block_q_major_dkv, q_seq)
        )

    def get_partition_specs(self, qs) -> Tuple[
        PartitionSpec, PartitionSpec, PartitionSpec, PartitionSpec, PartitionSpec, bool
    ]:
        is_generating = qs == 1
        query_sequence_partition = self.generation_query_partition_spec if is_generating else self.query_partition_spec
        bias_partition_spec = self.generation_bias_partition_spec if is_generating else self.bias_partition_spec
        attention_partition_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec

        return (
            query_sequence_partition,
            self.key_partition_spec,
            self.value_partition_spec,
            bias_partition_spec,
            attention_partition_spec,
            is_generating
        )

    def _check_states(
            self,
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
    ):
        batch_size = query_states.shape[0]
        assert batch_size == key_states.shape[0] == value_states.shape[0], "Batch Size for q,k,v wont match"
        k_v_req_shape = (
            batch_size,
            key_value_sequence_length,
            self.num_attention_heads,
            self.head_dims
        )
        q_shape = (
            batch_size,
            query_sequence_length,
            self.num_attention_heads,
            self.head_dims
        )

        assertion_mkv_err = f"""
        query_states, key_states, value_states and bias shapes must be like
        query_states Shape : [batch_size, q_seq_len , {self.num_attention_heads=}, {self.head_dims=}]
        key_states   Shape : [batch_size, kv_seq_len, {self.num_attention_heads=}, {self.head_dims=}]
        value_states Shape : [batch_size, kv_seq_len, {self.num_attention_heads=}, {self.head_dims=}]
        bias         Shape : [batch_size, {self.num_attention_heads=}, q_seq_len , kv_seq_len]
            """

        assert query_states.shape == q_shape, assertion_mkv_err + (
            f"\nMiss Match {query_states.shape} and "
            f"required Shape {q_shape}"
        )
        assert key_states.shape == k_v_req_shape, assertion_mkv_err + (
            f"\nMiss Match {key_states.shape} and "
            f"required Shape {k_v_req_shape}"
        )
        assert value_states.shape == k_v_req_shape, assertion_mkv_err + (
            f"\nMiss Match {value_states.shape} and "
            f"required Shape {k_v_req_shape}"
        )

    def __call__(
            self,
            query_states: Array,
            key_states: Array,
            value_states: Array,
            causal_mask: Optional[Array] = None,
            query_sequence_length: Optional[int] = None,
            key_value_sequence_length: Optional[int] = None,
            bias: Optional[Array] = None,
            attention_mask: Optional[Array] = None,
            segment_ids: Optional[Array] = None,
            causal: bool = True,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            uses_cache: bool = False
    ):
        if query_sequence_length is None:
            query_sequence_length = query_states.shape[1]
        if key_value_sequence_length is None:
            key_value_sequence_length = key_states.shape[1]
        with self.mesh:
            self._check_states(
                query_states=query_states,
                key_states=key_states,
                value_states=value_states,
                query_sequence_length=query_sequence_length,
                key_value_sequence_length=key_value_sequence_length
            )
            if self.attn_mechanism == "flash":
                if segment_ids is not None:
                    warnings.warn(
                        "Flash attention don't support `segment_ids` this argument will be ignored",
                        UserWarning
                    )
                if self.attention_dropout != 0.0:
                    warnings.warn(
                        "Flash attention don't support `attention_dropout` this argument will be ignored",
                        UserWarning
                    )

                return self.flash_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    causal=causal,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )

            elif self.attn_mechanism == "vanilla":

                return self.vanilla_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    dropout_rng=dropout_rng,
                    deterministic=deterministic,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "sharded_vanilla":
                return self.sharded_vanilla_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    dropout_rng=dropout_rng,
                    deterministic=deterministic,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "ring":
                return self.ring_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    dropout_rng=dropout_rng,
                    deterministic=deterministic,
                    segment_ids=segment_ids,
                    attention_mask=attention_mask,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "pallas_flash":
                return self.pallas_flash_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    query_sequence_length=query_sequence_length,
                    bias=bias,
                )
            elif self.attn_mechanism == "splash":
                if segment_ids is not None:
                    warnings.warn(
                        "Splash attention don't support `segment_ids` this argument will be ignored",
                        UserWarning
                    )
                if self.attention_dropout != 0.0:
                    warnings.warn(
                        "Splash attention don't support `attention_dropout` this argument will be ignored",
                        UserWarning
                    )
                if bias is not None:
                    warnings.warn(
                        "Splash attention don't support `bias` this argument will be ignored",
                        UserWarning
                    )

                return self.splash_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length,
                    attention_mask=attention_mask
                )
            elif self.attn_mechanism == "blockwise":
                if segment_ids is not None:
                    warnings.warn(
                        "BlockWise Attention don't support `segment_ids` this argument will be ignored",
                        UserWarning
                    )
                return self.blockwise_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    deterministic=deterministic,
                    dropout_rng=dropout_rng,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "cudnn":
                return self.cuddn_flash_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    causal=causal,
                    deterministic=deterministic,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "local_ring":
                if segment_ids is not None:
                    warnings.warn(
                        "LocalRing Attention don't support `segment_ids` this argument will be ignored",
                        UserWarning
                    )
                if self.attention_dropout != 0.0:
                    warnings.warn(
                        "LocalRing Attention don't support `attention_dropout` this argument will be ignored",
                        UserWarning
                    )

                return self.local_ring_attention(
                    query_states=query_states,
                    key_states=key_states,
                    value_states=value_states,
                    bias=bias,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            elif self.attn_mechanism == "wise_ring":
                if segment_ids is not None:
                    warnings.warn(
                        "WiseRing Attention don't support `segment_ids` this argument will be ignored",
                        UserWarning
                    )
                if self.attention_dropout != 0.0:
                    warnings.warn(
                        "WiseRing Attention don't support `attention_dropout` this argument will be ignored",
                        UserWarning
                    )

                return self.wise_ring_attention(
                    query_states=query_states,
                    bias=bias,
                    value_states=value_states,
                    key_states=key_states,
                    segment_ids=segment_ids,
                    query_sequence_length=query_sequence_length,
                    key_value_sequence_length=key_value_sequence_length
                )
            else:
                raise ValueError(f"Unknown Attention mechanism of {self.attn_mechanism}")

    def local_ring_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
            bias: Optional[Array] = None,
    ):
        """Run ``ring_attention_standard`` under ``shard_map`` over the mesh.

        Partition specs come from ``get_partition_specs`` for the given query
        length. ``key_value_sequence_length`` is accepted but not read.

        Returns:
            ``AttentionOutput`` with ``attention_weights=None``.
        """
        query_spec, key_spec, value_spec, bias_spec, output_spec, _ = self.get_partition_specs(
            query_sequence_length
        )
        # NOTE(review): this passes 1 / sm_scale, while the fallback branch of
        # `ring_attention` passes sm_scale to the same kernel — confirm which
        # convention `ring_attention_standard` expects.
        ring_kernel = partial(
            ring_attention_standard,
            axis_name=self.axis_name,
            scale=1 / self.sm_scale,
            float32_logits=True,
        )
        sharded_kernel = shard_map(
            ring_kernel,
            mesh=self.mesh,
            in_specs=(query_spec, key_spec, value_spec, bias_spec),
            out_specs=output_spec,
            check_rep=False,
        )
        attention_outputs = sharded_kernel(query_states, key_states, value_states, bias)
        return AttentionOutput(
            attention_weights=None,
            attention_outputs=attention_outputs,
        )

    def ring_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
            bias: Optional[Array] = None,
            attention_mask: Optional[Array] = None,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            segment_ids: Optional[Array] = None,
    ):
        """Run ring attention, blockwise when enabled and the query is long enough.

        The blockwise path is taken when ``scan_ring_attention`` is set and
        the query length exceeds ``max(block_q, block_k)``; otherwise
        ``ring_attention_standard`` is used, with a warning when the platform
        is not TPU. ``segment_ids`` defaults to zeros of shape
        ``(batch, query_sequence_length)``. ``key_value_sequence_length`` is
        accepted but not read.

        Returns:
            ``AttentionOutput`` with ``attention_weights=None``.
        """
        if segment_ids is None:
            segment_ids = jnp.zeros((query_states.shape[0], query_sequence_length), dtype="i4")

        use_blockwise_ring = self.scan_ring_attention and query_states.shape[1] > max(
            self.block_q,
            self.block_k
        )

        if not use_blockwise_ring:
            if self.platform != "tpu":
                warnings.warn(
                    "Using Ring attention on CPUs or GPUs are not recommended due to miss computations at the moment. "
                    "please refer to other types of attention mechanism.your are bing fell back on "
                    "`ring_attention_sharded`"
                    f" Usage conditions was\nscan_ring_attention = {self.scan_ring_attention} [MUST BE TRUE]"
                    f"\nquery_states.shape[1]({query_states.shape[1]}) > max({self.block_q},{self.block_k})"
                    f"({max(self.block_q, self.block_k)})"
                )
            # A single-token query is not split along the sequence axis.
            query_sequence_partition = None if query_states.shape[1] == 1 else "sp"
            query_like_spec = PartitionSpec(("dp", "fsdp"), query_sequence_partition, "tp", None)
            key_value_spec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
            # NOTE(review): scale is sm_scale here but 1 / sm_scale in
            # `local_ring_attention` — confirm the intended convention.
            standard_kernel = shard_map(
                partial(
                    ring_attention_standard,
                    axis_name=self.axis_name,
                    scale=self.sm_scale
                ),
                mesh=self.mesh,
                in_specs=(
                    query_like_spec,
                    key_value_spec,
                    key_value_spec,
                    PartitionSpec(("dp", "fsdp"), None, query_sequence_partition, None)
                ),
                out_specs=query_like_spec,
                check_rep=False
            )
            # NOTE(review): `attention_mask`, not `bias`, is the fourth operand on this path.
            return AttentionOutput(
                attention_weights=None,
                attention_outputs=standard_kernel(
                    query_states, key_states, value_states, attention_mask
                )
            )

        if self.platform == "tpu":
            ring_attention_fn = ring_flash_attention_tpu
        else:
            ring_attention_fn = fjformer.pallas_operations.ring_attention

        blockwise_options = dict(
            deterministic=deterministic,
            dropout_rng=dropout_rng,
            attn_pdrop=self.attention_dropout,
            causal=True,
            query_chunk_size=self.block_q,
            key_chunk_size=self.block_k,
            dtype=self.dtype,
            policy=get_gradient_checkpoint_policy("nothing_saveable"),
            precision=self.precision,
            prevent_cse=not self.scan_attention_layers,
        )
        blockwise_kernel = shard_map(
            partial(
                ring_attention_fn,
                axis_name=self.axis_name,
                float32_logits=True,
                blockwise_kwargs=blockwise_options
            ),
            mesh=self.mesh,
            in_specs=(
                self.query_partition_spec,
                self.key_partition_spec,
                self.value_partition_spec,
                self.bias_partition_spec,
                PartitionSpec(("dp", "fsdp"), None),
            ),
            out_specs=self.attention_partition_spec,
            check_rep=False
        )
        attn_output = blockwise_kernel(query_states, key_states, value_states, bias, segment_ids)
        attn_output = with_sharding_constraint(attn_output, self.attention_partition_spec)
        return AttentionOutput(
            attention_weights=None,
            attention_outputs=attn_output
        )

    def wise_ring_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
            bias: Optional[Array] = None,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            segment_ids: Optional[Array] = None
    ):
        """Run the ``wise_ring_attention`` kernel, or fall back to local ring attention.

        The sharded kernel is used when ``scan_ring_attention`` is set and
        the query length exceeds ``max(block_q, block_k)``. Otherwise a
        warning is issued and the call is delegated to
        ``local_ring_attention``. ``segment_ids`` defaults to zeros of shape
        ``(batch, query_sequence_length)``.

        Returns:
            ``AttentionOutput`` with ``attention_weights=None``.
        """
        if segment_ids is None:
            segment_ids = jnp.zeros((query_states.shape[0], query_sequence_length), dtype="i4")

        largest_block = max(self.block_q, self.block_k)
        if not (self.scan_ring_attention and query_states.shape[1] > largest_block):
            seq_length = query_states.shape[1]
            chunk = seq_length > max(self.block_q, self.block_k)
            warnings.warn(
                f"generation process detected, switching to local ring attention"
                f" [CHUNK : {chunk}, SCAN : {self.scan_ring_attention}, {self.block_k=}, {self.block_q=}, {seq_length=}]"
            )
            return self.local_ring_attention(
                query_states=query_states,
                key_states=key_states,
                value_states=value_states,
                bias=bias,
                query_sequence_length=query_sequence_length,
                key_value_sequence_length=key_value_sequence_length
            )

        block_wise_options = dict(
            deterministic=deterministic,
            dropout_rng=dropout_rng,
            attn_pdrop=self.attention_dropout,
            causal=True,
            query_chunk_size=self.block_q,
            key_chunk_size=self.block_k,
            dtype=self.dtype,
            policy=get_gradient_checkpoint_policy("nothing_saveable"),
            precision=self.precision,
            prevent_cse=not self.scan_attention_layers,
        )
        sharded_kernel = shard_map(
            partial(
                wise_ring_attention,
                axis_name=self.axis_name,
                float32_logits=True,
                block_wise_kwargs=block_wise_options
            ),
            mesh=self.mesh,
            in_specs=(
                self.query_partition_spec,
                self.key_partition_spec,
                self.value_partition_spec,
                self.bias_partition_spec,
                PartitionSpec(("dp", "fsdp"), "sp"),
            ),
            out_specs=self.attention_partition_spec,
            check_rep=False
        )
        attn_output = sharded_kernel(query_states, key_states, value_states, bias, segment_ids)
        attn_output = with_sharding_constraint(attn_output, self.attention_partition_spec)
        return AttentionOutput(
            attention_weights=None,
            attention_outputs=attn_output
        )

    def vanilla_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            bias: Optional[Array] = None,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            query_sequence_length: int,
            key_value_sequence_length: int,
    ) -> AttentionOutput:
        """Call the module-level ``vanilla_attention`` function inside the mesh context.

        The computation dtype is ``self.dtype`` promoted with float32. The
        two sequence-length arguments are accepted but not read.

        Returns:
            ``AttentionOutput`` carrying both the outputs and the attention
            weights returned by the function.
        """
        compute_dtype = jnp.promote_types(self.dtype, jnp.float32)
        call_kwargs = dict(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=bias,
            deterministic=deterministic,
            dtype=compute_dtype,
            dropout_rng=dropout_rng,
            precision=self.precision,
            attention_dropout=self.attention_dropout,
            shard_attention_computation=self.shard_attention_computation,
        )
        with self.mesh:
            # Inside a method this name resolves to the module-level function,
            # not to this method.
            outputs, weights = vanilla_attention(**call_kwargs)
            return AttentionOutput(
                attention_weights=weights,
                attention_outputs=outputs
            )

    def blockwise_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            bias: Optional[Array] = None,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            query_sequence_length: int,
            key_value_sequence_length: int,
    ) -> AttentionOutput:
        """Run ``blockwise_attn`` with sharding constraints on inputs and output.

        Chunk sizes are the ``block_q`` / ``block_k`` fields returned by
        ``get_block_size_flash_attn`` for the given sequence lengths. The
        computation dtype is ``self.dtype`` promoted with float32.

        Returns:
            ``AttentionOutput`` with ``attention_weights=None``.
        """
        compute_dtype = jnp.promote_types(self.dtype, jnp.float32)
        query_spec, _, _, bias_spec, output_spec, _ = self.get_partition_specs(qs=query_sequence_length)
        chunk_sizes = self.get_block_size_flash_attn(query_sequence_length, key_value_sequence_length)
        with self.mesh:
            # Constrain every operand before the kernel runs.
            query_states = with_sharding_constraint(query_states, query_spec)
            key_states = with_sharding_constraint(key_states, self.key_partition_spec)
            value_states = with_sharding_constraint(value_states, self.value_partition_spec)
            bias = with_sharding_constraint(bias, bias_spec)

            attention_outputs = blockwise_attn(
                query=query_states,
                key=key_states,
                value=value_states,
                bias=bias,
                deterministic=deterministic,
                dtype=compute_dtype,
                dropout_rng=dropout_rng,
                precision=self.precision,
                attn_pdrop=self.attention_dropout,
                key_chunk_size=chunk_sizes.block_k,
                query_chunk_size=chunk_sizes.block_q,
                prevent_cse=not self.scan_attention_layers,
                causal=True,
                float32_logits=True
            )
            attention_outputs = with_sharding_constraint(attention_outputs, output_spec)
            return AttentionOutput(
                attention_weights=None,
                attention_outputs=attention_outputs
            )

    def sharded_vanilla_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            bias: Optional[Array] = None,
            deterministic: bool = False,
            dropout_rng: Optional[random.PRNGKey] = None,
            query_sequence_length: int,
            key_value_sequence_length: int,
    ) -> AttentionOutput:
        """
        Scaled dot-product attention with explicit sharding constraints on every operand.

        Query, key and value are constrained to the specs returned by
        `self.get_partition_specs`, promoted to `self.dtype` joined with float32, and combined
        through two einsums. The softmax itself is evaluated in float32. When
        `deterministic` is False and `self.attention_dropout > 0`, a dropout mask drawn from
        `dropout_rng` is applied to the attention weights.

        :return: AttentionOutput holding the attention weights and the attention outputs.
        """
        compute_dtype = jnp.promote_types(self.dtype, jnp.float32)
        q_spec, k_spec, v_spec, bias_spec, out_spec, _ = self.get_partition_specs(qs=query_sequence_length)

        with self.mesh:
            query_states = fjformer.with_sharding_constraint(query_states, q_spec)
            key_states = fjformer.with_sharding_constraint(key_states, k_spec)
            value_states = fjformer.with_sharding_constraint(value_states, v_spec)

            # Shape sanity checks between query and key.
            assert query_states.ndim == key_states.ndim, "q, k must have same rank."
            assert query_states.shape[:-3] == key_states.shape[:-3], "q, k batch dims must match."
            assert query_states.shape[-2] == key_states.shape[-2], "q, k num_heads must match."
            assert query_states.shape[-1] == key_states.shape[-1], "q, k depths must match."

            query_states, key_states, value_states = promote_dtype(
                query_states, key_states, value_states,
                dtype=compute_dtype
            )

            # Scale the query by 1/sqrt(head depth) before taking q.k scores.
            head_depth = query_states.shape[-1]
            query_states = query_states / jnp.sqrt(head_depth).astype(compute_dtype)
            scores = jnp.einsum(
                "...qhd,...khd->...hqk",
                query_states,
                key_states,
                precision=self.precision
            )
            if bias is not None:
                bias = fjformer.with_sharding_constraint(bias, bias_spec)
                scores = jnp.add(scores, bias)

            attention_weight = jax.nn.softmax(scores.astype(jnp.float32)).astype(compute_dtype)

            use_dropout = (not deterministic) and self.attention_dropout > 0.0
            if use_dropout:
                keep_prob = 1.0 - self.attention_dropout
                # Mask shape is (1, ..., 1, q, k): one mask broadcast over the leading dims.
                mask_shape = (1,) * (key_states.ndim - 2) + attention_weight.shape[-2:]
                keep_mask = random.bernoulli(dropout_rng, keep_prob, mask_shape)  # type: ignore
                rescale = keep_mask.astype(compute_dtype) / jnp.asarray(keep_prob, dtype=compute_dtype)
                attention_weight = attention_weight * rescale

            attention_outputs = fjformer.with_sharding_constraint(
                jnp.einsum(
                    "...hqk,...khd->...qhd",
                    attention_weight,
                    value_states,
                    precision=self.precision
                ),
                out_spec
            )
            return AttentionOutput(
                attention_weights=attention_weight,
                attention_outputs=attention_outputs
            )

    def flash_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
            bias: Optional[Array] = None,
            causal: bool = False,
    ) -> AttentionOutput:
        """
        Run the platform flash-attention kernel returned by `get_flash_attention` under `shard_map`.

        Inputs are transposed with `(0, 2, 1, 3)` before the kernel call and the result is
        transposed back the same way, so the returned outputs use the caller's layout.
        If `bias` is given and its axis 1 does not match the number of heads it is repeated
        along that axis. When `self.sm_scale` is None it is initialised to
        `1 / sqrt(head_dims)` and stored on the instance.

        :param causal: forwarded to the flash kernel.
        :return: AttentionOutput with `attention_weights=None`.
        """

        qps, kps, vps, bps, aps, is_gen = self.get_partition_specs(qs=query_sequence_length)
        block_size = self.get_block_size_flash_attn(query_sequence_length, key_value_sequence_length)
        query_states = query_states.transpose(0, 2, 1, 3)
        key_states = key_states.transpose(0, 2, 1, 3)
        value_states = value_states.transpose(0, 2, 1, 3)

        batch_size, num_attention_heads, query_sequence_length, head_dims = query_states.shape
        if bias is not None and bias.shape[1] != num_attention_heads:
            bias = bias.repeat(num_attention_heads, 1, )

        flash_func, float32_logits, _ = get_flash_attention()
        if float32_logits:
            query_states, key_states, value_states = map(
                lambda s: s.astype(jnp.float32),
                (query_states, key_states, value_states)
            )

        if self.sm_scale is None:
            # The softmax scale is derived from the head dimension (a Python int taken from
            # the static shape); `math.sqrt` cannot be applied to an array slice.
            self.sm_scale = 1 / math.sqrt(head_dims)
        attention_o = shard_map(
            partial(
                flash_func,
                causal=causal,
                sm_scale=self.sm_scale,
                block_sizes=block_size,
                debug=False
            ),
            in_specs=(qps, kps, vps, bps),
            out_specs=aps,
            mesh=self.mesh,
            check_rep=False,
        )(
            query_states,
            key_states,
            value_states,
            bias,
        )

        attention_o = attention_o.transpose(0, 2, 1, 3)
        return AttentionOutput(
            attention_outputs=attention_o,
            attention_weights=None
        )

    def splash_attention(
            self,
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int,
            key_value_sequence_length: int,
            attention_mask: Array
    ) -> AttentionOutput:
        """
        Run a causal splash-attention kernel per shard via `shard_map`.

        Query, key and value are transposed with `(0, 2, 1, 3)` and cast to float32.
        A non-None `attention_mask` is reduced with `[:, 0, -1]` when it is 4-D and then
        wrapped as `SegmentIds(mask, mask)`; when it is None a warning is emitted and None
        is forwarded to the kernel call.

        :return: AttentionOutput with `attention_weights=None`; outputs are transposed back
            with `(0, 2, 1, 3)`.
        """

        q_spec, k_spec, v_spec, _, _, _ = self.get_partition_specs(qs=query_sequence_length)

        query_states, key_states, value_states = (
            states.transpose(0, 2, 1, 3).astype(jnp.float32)
            for states in (query_states, key_states, value_states)
        )

        if attention_mask is None:
            warnings.warn("`attention_mask` is not passed to SplashAttention. (except miss computation problem)")
        else:
            if attention_mask.ndim == 4:
                attention_mask = attention_mask[:, 0, -1]
            attention_mask = SegmentIds(attention_mask, attention_mask)

        def splash_attention_call(q, k, v, am):
            # Built inside the shard so the causal masks match the per-shard q/k lengths.
            per_head_masks = [CausalMask(shape=(q.shape[2], k.shape[2])) for _ in range(q.shape[1])]
            kernel = make_splash_mha(
                mask=MultiHeadMask(masks=per_head_masks),
                head_shards=1,
                q_seq_shards=1,
                block_sizes=self.get_block_size_splash_attn(query_sequence_length, key_value_sequence_length)
            )
            # vmap the kernel over the leading (batch) axis.
            return jax.vmap(kernel)(q, k, v, segment_ids=am)

        sharded_call = shard_map(
            splash_attention_call,
            in_specs=(q_spec, k_spec, v_spec, PartitionSpec(q_spec[0], q_spec[2])),  # make it easier
            out_specs=q_spec,
            mesh=self.mesh,
            check_rep=False,
        )

        attention_o = sharded_call(query_states, key_states, value_states, attention_mask)
        return AttentionOutput(
            attention_outputs=attention_o.transpose(0, 2, 1, 3),
            attention_weights=None
        )

    def pallas_flash_attention(
            self,
            *,
            query_states: Array,
            key_states: Array,
            value_states: Array,
            query_sequence_length: int = None,
            bias: Optional[Array] = None,
    ) -> AttentionOutput:
        """
        Call the module-level `flash_attention` kernel on float32, sharding-constrained inputs.

        When `query_sequence_length` is None it is read from `query_states.shape[1]`.
        The kernel runs with `interpret=True` when `self.platform == "cpu"`, otherwise
        `interpret=None` is passed.

        :return: AttentionOutput with `attention_weights=None`.
        """
        effective_q_len = query_states.shape[1] if query_sequence_length is None else query_sequence_length
        q_spec, k_spec, v_spec, bias_spec, out_spec, _ = self.get_partition_specs(qs=effective_q_len)

        # Upcast to float32 first, then pin each operand to its partition spec.
        query_states = with_sharding_constraint(query_states.astype(jnp.float32), q_spec)
        key_states = with_sharding_constraint(key_states.astype(jnp.float32), k_spec)
        value_states = with_sharding_constraint(value_states.astype(jnp.float32), v_spec)
        bias = with_sharding_constraint(bias, bias_spec)

        interpret_mode = True if self.platform == "cpu" else None  # auto-decide
        kernel_out = flash_attention(
            query_states,
            key_states,
            value_states,
            bias=bias,
            sm_scale=self.sm_scale,
            block_k=self.block_k,
            block_q=self.block_q,
            interpret=interpret_mode,
            backward_pass_impl=self.backward_pass_impl
        )
        return AttentionOutput(
            attention_weights=None,
            attention_outputs=with_sharding_constraint(kernel_out, out_spec)
        )

    def cuddn_flash_attention(
            self,
            *,  # it's Kwarg Only
            query_states: Array,
            key_states: Array,
            value_states: Array,
            bias: Optional[Array] = None,
            causal: bool = False,
            deterministic: bool = True,
            query_sequence_length: int,
            key_value_sequence_length: int,
    ) -> AttentionOutput:
        """
        CUDNN Flash Attention with Transformer Engine.

        Query, key and value are packed into one array along a new axis 2 (layout
        `QKVLayout.BS3HD`) and handed to `fused_attn.self_fused_attn`. The mask type is
        always `CAUSAL_MASK` and the bias type is always `NO_BIAS`; `bias` is still
        forwarded to the kernel call. The `query_sequence_length` argument is replaced by
        the value read from `query_states.shape`. When `self.sm_scale` is None it is
        initialised to `1 / sqrt(head_dim)` and stored on the instance.

        :param causal: when True a zero mask of shape (batch, 1, q_len, kv_len) is passed,
            otherwise `mask=None`.
        :param deterministic: when True the kernel is run with `is_training=False`.
        :raises RuntimeError: if `transformer_engine` cannot be imported.
        :raises ValueError: if no fused attention kernel is available for the request.
        :return: AttentionOutput with `attention_weights=None`.
        """
        try:
            import transformer_engine.jax.fused_attn as fused_attn
            from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
            from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
        except (ModuleNotFoundError, ImportError) as err:
            # Chain the original import failure so the real cause stays in the traceback.
            raise RuntimeError(
                "Please install transformer_engine first. you can install that by running "
                f"`pip install git+https://github.com/NVIDIA/TransformerEngine`"
                f"\nhere's extra information on error\n{err}"
            ) from err
        batch, query_sequence_length, num_attention_heads, head_dim = query_states.shape

        qkv_layout = QKVLayout.BS3HD
        attn_mask_type = AttnMaskType.CAUSAL_MASK
        attn_bias_type = AttnBiasType.NO_BIAS

        if self.sm_scale is None:
            self.sm_scale = 1 / math.sqrt(head_dim)
        has_fused_attn_kernel = is_fused_attn_kernel_available(
            self.dtype, self.dtype, qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            key_states.shape[2],
            query_sequence_length,
            key_value_sequence_length,
            head_dim
        )

        if not has_fused_attn_kernel:
            raise ValueError(
                "Flash attention kernel is not supported for current requested arrays"
                " for details check this repo https://github.com/NVIDIA/TransformerEngine/"
            )

        # All three operands are reshaped to the query's shape with a singleton axis 2,
        # then concatenated along that axis to form the packed qkv array.
        packed_shape = (*query_states.shape[:2], 1, *query_states.shape[-2:])
        qkv = jnp.concatenate(
            (
                jnp.reshape(query_states, packed_shape),
                jnp.reshape(key_states, packed_shape),
                jnp.reshape(value_states, packed_shape)
            ),
            axis=2
        )

        return AttentionOutput(
            attention_weights=None,
            attention_outputs=fused_attn.self_fused_attn(
                qkv=qkv,
                bias=bias,
                mask=jnp.zeros((batch, 1, query_sequence_length, key_value_sequence_length)) if causal else None,
                seed=None,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=self.sm_scale,
                dropout_probability=self.attention_dropout,
                # Training mode is the opposite of deterministic mode.
                is_training=not deterministic
            )
        )

    @staticmethod
    def test_attentions(
            batch_size=8,
            sequence_length=128 * 8,
            num_attention_heads=32,
            num_key_value_heads=32,
            chunk_size=128,
            axis_dims=(1, -1, 1, 1)
    ):
        """
        Benchmark every attention mechanism against `flax.linen.dot_product_attention`.

        For each mechanism the summed attention output and its gradient (with respect to the
        first argument, the query) are compared to the reference, and the wall-clock time of
        the call is recorded. Mechanisms that raise are reported with "NA" entries.

        :return: a transposed pandas DataFrame when pandas is importable, otherwise the raw
            dict keyed by upper-cased mechanism name.
        """
        import flax
        try:
            import pandas
        except ImportError:
            warnings.warn("couldn't import pandas ... please install pandas")
            pandas = None
        from ..modules.mistral import MistralConfig
        from fjformer import GenerateRNG
        head_dim = 128
        rng = GenerateRNG()

        config = MistralConfig(
            axis_dims=axis_dims,
            block_q=chunk_size,
            block_k=chunk_size
        )

        def value_and_grad_wrapper(fn, **kwargs):
            # Reduce the wrapped function's output to a scalar so it can be differentiated.
            def summed(*args, **call_kwargs):
                return jnp.sum(fn(*args, **call_kwargs))

            return jax.value_and_grad(summed, **kwargs)

        def diff(t1, t2):
            return jnp.max(jnp.abs(t1 - t2))

        @value_and_grad_wrapper
        def call_dot_product(q, k, v, b, ):
            return flax.linen.dot_product_attention(q, k, v, b, )

        @value_and_grad_wrapper
        def call_attention_module(q, k, v, b, a, attn_mechanism):
            module = AttentionModule(
                attn_mechanism=attn_mechanism,
                axis_name="sp",
                dtype=jnp.float32,
                mesh=config.jax_mesh(),
                head_dims=q.shape[-1],
                num_attention_heads=q.shape[-2],
                block_q=config.block_q,
                block_k=config.block_k
            )
            return module(
                query_states=q,
                key_states=k,
                value_states=v,
                bias=b,
                attention_mask=a
            ).attention_outputs

        def make_inputs():
            query_shape = (batch_size, sequence_length, num_attention_heads, head_dim)
            key_value_shape = (batch_size, sequence_length, num_key_value_heads, head_dim)
            q = jax.random.normal(rng.rng, query_shape, dtype="float32")
            k = jax.random.normal(rng.rng, key_value_shape, dtype="float32")
            v = jax.random.normal(rng.rng, key_value_shape, dtype="float32")
            # Attention mask with the second half of every sequence zeroed out.
            a = jnp.ones((batch_size, sequence_length)).at[:, sequence_length // 2:].set(0)
            c = flax.linen.attention.make_causal_mask(jnp.ones((batch_size, sequence_length)))
            combined = flax.linen.attention.combine_masks(a[:, None, None, :], c)
            # Additive bias: 0 where attention is allowed, -inf elsewhere.
            b = jnp.where(combined, 0, -jnp.inf)

            return q, k, v, b, a

        q, k, v, b, a = make_inputs()
        expected_output, expected_grads = call_dot_product(q, k, v, b)
        mechanisms = (
            "local_ring",
            "blockwise",
            "vanilla",
            "wise_ring",
            "sharded_vanilla",
            "flash",
            "splash",
            "cudnn",
            "pallas_flash"
        )

        timed_results = {}
        for mechanism in mechanisms:
            run = partial(call_attention_module, attn_mechanism=mechanism)
            try:
                started_at = time.time()
                value_and_grad = jax.block_until_ready(run(q, k, v, b, a))
                elapsed = time.time() - started_at
                timed_results[mechanism] = value_and_grad + (elapsed,)
            except Exception as e:
                print(f"{mechanism} is Failed :\n\n{e}")
                timed_results[mechanism] = (None, None, None)

        report = {}
        for mechanism, (out, grad, time_took) in timed_results.items():
            row_key = mechanism.upper()
            if out is None and grad is None:
                report[row_key] = {
                    "OUT DIFF": "NA",
                    "GRADIENT DIFF SUM": "NA",
                    "TEST PASSED": "NA",
                    "COMP TIME": "NA"
                }
                continue

            output_diff = diff(expected_output, out)
            sum_g = sum(diff(ref, got) for ref, got in zip(expected_grads, grad))
            # TODO : Fix this
            # Pulling `sum_g` / `output_diff` to host with jax.device_get raised
            # XlaRuntimeError: FAILED_PRECONDITION: The program continuator has halted unexpectedly.
            report[row_key] = {
                "OUT DIFF": output_diff,
                "GRADIENT DIFF SUM": sum_g,
                "TEST PASSED": sum_g < 1 and output_diff < 1e-2,
                "COMP TIME": time_took
            }

        if pandas is None:
            return report
        return pandas.DataFrame.from_dict(report).transpose()

cuddn_flash_attention(*, query_states, key_states, value_states, bias=None, causal=False, deterministic=True, query_sequence_length, key_value_sequence_length)

CUDNN Flash Attention with Transformer Engine.

Source code in src/python/easydel/modules/attention_module.py
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
def cuddn_flash_attention(
        self,
        *,  # it's Kwarg Only
        query_states: Array,
        key_states: Array,
        value_states: Array,
        bias: Optional[Array] = None,
        causal: bool = False,
        deterministic: bool = True,
        query_sequence_length: int,
        key_value_sequence_length: int,
) -> AttentionOutput:
    """
    CUDNN Flash Attention with Transformer Engine.

    Query, key and value are packed into one array along a new axis 2 (layout
    `QKVLayout.BS3HD`) and handed to `fused_attn.self_fused_attn`. The mask type is
    always `CAUSAL_MASK` and the bias type is always `NO_BIAS`; `bias` is still
    forwarded to the kernel call. The `query_sequence_length` argument is replaced by
    the value read from `query_states.shape`. When `self.sm_scale` is None it is
    initialised to `1 / sqrt(head_dim)` and stored on the instance.

    :param causal: when True a zero mask of shape (batch, 1, q_len, kv_len) is passed,
        otherwise `mask=None`.
    :param deterministic: when True the kernel is run with `is_training=False`.
    :raises RuntimeError: if `transformer_engine` cannot be imported.
    :raises ValueError: if no fused attention kernel is available for the request.
    :return: AttentionOutput with `attention_weights=None`.
    """
    try:
        import transformer_engine.jax.fused_attn as fused_attn
        from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
        from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
    except (ModuleNotFoundError, ImportError) as err:
        # Chain the original import failure so the real cause stays in the traceback.
        raise RuntimeError(
            "Please install transformer_engine first. you can install that by running "
            f"`pip install git+https://github.com/NVIDIA/TransformerEngine`"
            f"\nhere's extra information on error\n{err}"
        ) from err
    batch, query_sequence_length, num_attention_heads, head_dim = query_states.shape

    qkv_layout = QKVLayout.BS3HD
    attn_mask_type = AttnMaskType.CAUSAL_MASK
    attn_bias_type = AttnBiasType.NO_BIAS

    if self.sm_scale is None:
        self.sm_scale = 1 / math.sqrt(head_dim)
    has_fused_attn_kernel = is_fused_attn_kernel_available(
        self.dtype, self.dtype, qkv_layout,
        attn_bias_type,
        attn_mask_type,
        self.attention_dropout,
        self.num_attention_heads,
        key_states.shape[2],
        query_sequence_length,
        key_value_sequence_length,
        head_dim
    )

    if not has_fused_attn_kernel:
        raise ValueError(
            "Flash attention kernel is not supported for current requested arrays"
            " for details check this repo https://github.com/NVIDIA/TransformerEngine/"
        )

    # All three operands are reshaped to the query's shape with a singleton axis 2,
    # then concatenated along that axis to form the packed qkv array.
    packed_shape = (*query_states.shape[:2], 1, *query_states.shape[-2:])
    qkv = jnp.concatenate(
        (
            jnp.reshape(query_states, packed_shape),
            jnp.reshape(key_states, packed_shape),
            jnp.reshape(value_states, packed_shape)
        ),
        axis=2
    )

    return AttentionOutput(
        attention_weights=None,
        attention_outputs=fused_attn.self_fused_attn(
            qkv=qkv,
            bias=bias,
            mask=jnp.zeros((batch, 1, query_sequence_length, key_value_sequence_length)) if causal else None,
            seed=None,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=self.sm_scale,
            dropout_probability=self.attention_dropout,
            # Training mode is the opposite of deterministic mode.
            is_training=not deterministic
        )
    )

test_attentions(batch_size=8, sequence_length=128 * 8, num_attention_heads=32, num_key_value_heads=32, chunk_size=128, axis_dims=(1, -1, 1, 1)) staticmethod

Creates a test for the attention module to help you find the best attention mechanism you can use.

Source code in src/python/easydel/modules/attention_module.py
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
@staticmethod
def test_attentions(
        batch_size=8,
        sequence_length=128 * 8,
        num_attention_heads=32,
        num_key_value_heads=32,
        chunk_size=128,
        axis_dims=(1, -1, 1, 1)
):
    """
    Benchmark every attention mechanism against `flax.linen.dot_product_attention`.

    For each mechanism the summed attention output and its gradient (with respect to the
    first argument, the query) are compared to the reference, and the wall-clock time of
    the call is recorded. Mechanisms that raise are reported with "NA" entries.

    :return: a transposed pandas DataFrame when pandas is importable, otherwise the raw
        dict keyed by upper-cased mechanism name.
    """
    import flax
    try:
        import pandas
    except ImportError:
        warnings.warn("couldn't import pandas ... please install pandas")
        pandas = None
    from ..modules.mistral import MistralConfig
    from fjformer import GenerateRNG
    head_dim = 128
    rng = GenerateRNG()

    config = MistralConfig(
        axis_dims=axis_dims,
        block_q=chunk_size,
        block_k=chunk_size
    )

    def value_and_grad_wrapper(fn, **kwargs):
        # Reduce the wrapped function's output to a scalar so it can be differentiated.
        def summed(*args, **call_kwargs):
            return jnp.sum(fn(*args, **call_kwargs))

        return jax.value_and_grad(summed, **kwargs)

    def diff(t1, t2):
        return jnp.max(jnp.abs(t1 - t2))

    @value_and_grad_wrapper
    def call_dot_product(q, k, v, b, ):
        return flax.linen.dot_product_attention(q, k, v, b, )

    @value_and_grad_wrapper
    def call_attention_module(q, k, v, b, a, attn_mechanism):
        module = AttentionModule(
            attn_mechanism=attn_mechanism,
            axis_name="sp",
            dtype=jnp.float32,
            mesh=config.jax_mesh(),
            head_dims=q.shape[-1],
            num_attention_heads=q.shape[-2],
            block_q=config.block_q,
            block_k=config.block_k
        )
        return module(
            query_states=q,
            key_states=k,
            value_states=v,
            bias=b,
            attention_mask=a
        ).attention_outputs

    def make_inputs():
        query_shape = (batch_size, sequence_length, num_attention_heads, head_dim)
        key_value_shape = (batch_size, sequence_length, num_key_value_heads, head_dim)
        q = jax.random.normal(rng.rng, query_shape, dtype="float32")
        k = jax.random.normal(rng.rng, key_value_shape, dtype="float32")
        v = jax.random.normal(rng.rng, key_value_shape, dtype="float32")
        # Attention mask with the second half of every sequence zeroed out.
        a = jnp.ones((batch_size, sequence_length)).at[:, sequence_length // 2:].set(0)
        c = flax.linen.attention.make_causal_mask(jnp.ones((batch_size, sequence_length)))
        combined = flax.linen.attention.combine_masks(a[:, None, None, :], c)
        # Additive bias: 0 where attention is allowed, -inf elsewhere.
        b = jnp.where(combined, 0, -jnp.inf)

        return q, k, v, b, a

    q, k, v, b, a = make_inputs()
    expected_output, expected_grads = call_dot_product(q, k, v, b)
    mechanisms = (
        "local_ring",
        "blockwise",
        "vanilla",
        "wise_ring",
        "sharded_vanilla",
        "flash",
        "splash",
        "cudnn",
        "pallas_flash"
    )

    timed_results = {}
    for mechanism in mechanisms:
        run = partial(call_attention_module, attn_mechanism=mechanism)
        try:
            started_at = time.time()
            value_and_grad = jax.block_until_ready(run(q, k, v, b, a))
            elapsed = time.time() - started_at
            timed_results[mechanism] = value_and_grad + (elapsed,)
        except Exception as e:
            print(f"{mechanism} is Failed :\n\n{e}")
            timed_results[mechanism] = (None, None, None)

    report = {}
    for mechanism, (out, grad, time_took) in timed_results.items():
        row_key = mechanism.upper()
        if out is None and grad is None:
            report[row_key] = {
                "OUT DIFF": "NA",
                "GRADIENT DIFF SUM": "NA",
                "TEST PASSED": "NA",
                "COMP TIME": "NA"
            }
            continue

        output_diff = diff(expected_output, out)
        sum_g = sum(diff(ref, got) for ref, got in zip(expected_grads, grad))
        # TODO : Fix this
        # Pulling `sum_g` / `output_diff` to host with jax.device_get raised
        # XlaRuntimeError: FAILED_PRECONDITION: The program continuator has halted unexpectedly.
        report[row_key] = {
            "OUT DIFF": output_diff,
            "GRADIENT DIFF SUM": sum_g,
            "TEST PASSED": sum_g < 1 and output_diff < 1e-2,
            "COMP TIME": time_took
        }

    if pandas is None:
        return report
    return pandas.DataFrame.from_dict(report).transpose()

get_flash_attention()

Returns: the flash attention function, whether an upcast to float32 is needed, and do_shard_map.

Source code in src/python/easydel/modules/attention_module.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def get_flash_attention() -> Tuple[Callable, bool, bool]:
    """
    Select the flash attention implementation for the default JAX backend.

    return: FlashAttention FN, Upcast Needed to float32, do_shard_map

    On "gpu" this warns that `cudnn` or `pallas_flash` should be preferred and returns
    `(flash_attention, False, True)`; on "tpu" it returns
    `(tpu_flash_attention, True, False)`.

    :raises ValueError: for any other platform.
    """
    # `jax.default_backend()` is the public accessor for the default backend's platform
    # name; it avoids reaching into the deprecated `jax.lib.xla_bridge` module.
    platform = jax.default_backend()
    if platform == "gpu":
        warnings.warn("for GPU backend use `cudnn` or `pallas_flash`")
        return flash_attention, False, True
    if platform == "tpu":
        return tpu_flash_attention, True, False
    raise ValueError(f"Unsupported platform {platform}")