modules.falcon.modelling_falcon_flax
FlaxFalconPretrainedModel
Bases: EasyDeLFlaxPretrainedModel
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
init_weights(rng, input_shape, params=None)
The init_weights function is used to initialize the weights of a model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `self` | | Access variables that belong to the class | required |
| `rng` | `PRNGKey` | `jax.random.PRNGKey` used to initialize the weights of the model | required |
| `input_shape` | `Tuple` | The shape of the input tensor | required |
| `params` | `FrozenDict` | Parameters of a pre-trained model | `None` |

Returns:

| Type | Description |
|---|---|
| `FrozenDict` | A `FrozenDict` of parameters |
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
apply_rotary_pos_embedding(tensor, sin_, cos_)
The apply_rotary_pos_embedding function applies a rotary positional embedding to the input tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | | The tensor to apply the positional embedding to | required |
| `sin_` | | Sine table used to rotate the tensor by half its length | required |
| `cos_` | | Cosine table multiplied element-wise with the tensor | required |

Returns:

| Type | Description |
|---|---|
| | A tensor with the same shape as its input |
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
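As a sketch of what this computation typically looks like, here is the standard RoPE formulation (the exact EasyDeL implementation may differ in layout):

```python
import jax.numpy as jnp

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def apply_rotary_pos_embedding(tensor, sin_, cos_):
    # Rotate each feature pair by a position-dependent angle:
    # out = tensor * cos + rotate_half(tensor) * sin
    return (tensor * cos_) + (rotate_half(tensor) * sin_)
```

With `cos_ = 1` and `sin_ = 0` (position 0), the input passes through unchanged.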
built_bloom_alibi(attention_mask, num_attention_heads)
The built_bloom_alibi function builds the ALiBi (Attention with Linear Biases) tensor used by the Bloom-style attention layers. Instead of learned positional embeddings, ALiBi adds a linearly increasing, head-specific bias to the attention scores, so each head attends with a different recency preference. The attention mask ensures that padding tokens contribute no bias.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `attention_mask` | | Masks out the padding tokens in the input sequence | required |
| `num_attention_heads` | | Number of attention heads in the model | required |

Returns:

| Type | Description |
|---|---|
| | A tensor of shape `(batch_size, num_attention_heads, 1, sequence_length)` |
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
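A self-contained sketch of the standard BLOOM-style ALiBi construction, returning the shape documented above (the exact EasyDeL implementation may differ):

```python
import math
import jax.numpy as jnp

def built_bloom_alibi(attention_mask, num_attention_heads):
    batch_size, seq_length = attention_mask.shape
    # Head-specific slopes form a geometric sequence; when the head count
    # is not a power of two, extra slopes are taken from the next one.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = base ** jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
    if closest_power_of_2 != num_attention_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = num_attention_heads - closest_power_of_2
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining, 2, dtype=jnp.float32)
        slopes = jnp.concatenate([slopes, extra_base ** extra_powers], axis=0)
    # Position indices of non-padding tokens; padding contributes zero bias.
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[:, None] * arange_tensor  # (batch, heads, seq)
    return alibi.reshape(batch_size, num_attention_heads, 1, seq_length)
```

For 4 heads the first slope is 2⁻² = 0.25, so the first head's bias grows by 0.25 per position.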
dropout_add(linen_drop, x, residual, deterministic)
The dropout_add function is a helper that applies dropout to its input and then adds the residual. Passing deterministic=True (as done during evaluation) disables the dropout while still adding the residual, so the same code path serves both training and inference; during training, gradients flow through both the dropout branch and the residual connection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `linen_drop` | `Dropout` | `flax.linen.Dropout`: the dropout layer to apply | required |
| `x` | `Array` | `chex.Array`: input to the dropout layer | required |
| `residual` | `Array` | `chex.Array`: residual added to the dropout output | required |
| `deterministic` | `bool` | Whether the dropout layer is disabled (`True`) or active (`False`) | required |

Returns:

| Type | Description |
|---|---|
| `Array` | The sum of the residual and the dropout output |
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
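Because flax.linen.Dropout can only be called inside a module, the following is a pure-JAX sketch of the same logic; the `rng` and `rate` arguments stand in for the state the Dropout layer carries, and the names are illustrative:

```python
import jax
import jax.numpy as jnp

def dropout_add(rng, x, residual, rate, deterministic):
    # With deterministic=True (evaluation), dropout is a no-op and the
    # residual is still added, matching the behavior described above.
    if deterministic or rate == 0.0:
        return residual + x
    # Inverted dropout: scale surviving activations by 1 / keep_prob.
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(rng, keep_prob, x.shape)
    return residual + jnp.where(keep, x / keep_prob, 0.0)
```

In the deterministic path the result is simply `residual + x`.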
precompute_falcon_freq_cis(max_position_embedding, head_dim, theta=10000)
The precompute_falcon_freq_cis function precomputes the sinusoidal frequency tables used for the rotary positional embeddings in the Falcon model. It takes three arguments: max_position_embedding, head_dim, and theta. The first two are self-explanatory; theta is the base of the geometric progression of inverse frequencies across the head dimension, so larger values of theta produce longer wavelengths. The default of 10000 is the standard RoPE base.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `max_position_embedding` | `int` | Maximum length of the sequence | required |
| `head_dim` | `int` | Size of the positional embedding | required |
| `theta` | `float` | Base that sets the frequencies of the sinusoids | `10000` |

Returns:

| Type | Description |
|---|---|
| | A tuple of two arrays (sine and cosine tables) |
Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
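A sketch of the standard precomputation (the exact EasyDeL return layout may differ; the shapes below follow the documented tuple of two arrays):

```python
import jax.numpy as jnp

def precompute_falcon_freq_cis(max_position_embedding, head_dim, theta=10000):
    # Geometric progression of inverse frequencies across feature pairs.
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    t = jnp.arange(max_position_embedding, dtype=jnp.float32)
    freqs = jnp.outer(t, inv_freq)                   # (seq_len, head_dim // 2)
    emb = jnp.concatenate([freqs, freqs], axis=-1)   # (seq_len, head_dim)
    return jnp.sin(emb), jnp.cos(emb)
```

At position 0 all angles are zero, so the sine table starts at 0 and the cosine table at 1.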