28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class FlaxVisionLlamaPreTrainedModel(EasyDeLFlaxPretrainedModel):
    """
    Base wrapper that ties a Flax VisionLlama module to `EasyDeLFlaxPretrainedModel`.

    Subclasses are expected to set `module_class`; this class instantiates it and provides
    parameter initialisation (`init_weights`), cache initialisation (`init_cache`) and the
    forward pass (`__call__`).
    """
    config_class = VisionLlamaConfig
    base_model_prefix = "model"
    # Overridden by concrete subclasses with the flax module to instantiate in __init__.
    module_class: nn.Module = None

    def __init__(
        self,
        config: VisionLlamaConfig,
        input_shape: Tuple = (4, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """
        Instantiate `module_class` and hand it to the parent constructor.

        :param config: VisionLlamaConfig: Model configuration forwarded to the module and the parent
        :param input_shape: Tuple: Shape of the dummy input forwarded to the parent constructor
        :param seed: int: Seed forwarded to the parent constructor
        :param dtype: jnp.dtype: dtype forwarded to both the module and the parent constructor
        :param _do_init: bool: Forwarded to the parent constructor unchanged
        :param kwargs: Extra keyword arguments forwarded to `module_class`
        """
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_cache(self, batch_size, max_length):
        """
        Create the "cache" variable collection by running `module.init` with `init_cache=True`.

        :param batch_size: Leading dimension of the dummy inputs used for initialisation
        :param max_length: Second dimension (sequence length) of the dummy inputs
        :return: The "cache" entry of the variables returned by `module.init`
        """
        # Integer token ids, consistent with init_weights and __call__ which both use "i4".
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
        vision_mask = jnp.ones((batch_size, max_length), dtype=bool)

        # A fixed PRNG key is used here; only the "cache" collection of the result is returned.
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, vision_mask, attention_mask, position_ids,
            return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise the module parameters, optionally filling only the keys missing from `params`.

        :param rng: jax.random.PRNGKey: Key that is split into the "params" and "dropout" streams
        :param input_shape: Tuple: Shape of the dummy `input_ids` used to trace the module
        :param params: FrozenDict: Existing parameters; when given, only the entries listed in
            `self._missing_keys` are taken from the fresh random initialisation
        :return: The freshly initialised parameters, or `params` completed with the missing keys
        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        vision_mask = jnp.ones(input_ids.shape, dtype=bool)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        random_params = self.module.init(
            {
                "params": params_rng,
                "dropout": dropout_rng
            },
            input_ids,
            vision_mask,
            attention_mask,
            position_ids,
            return_dict=False
        )["params"]

        if params is None:
            return random_params

        # Work on flat dicts so missing keys can be copied over one by one.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        # Every missing key has now been filled in.
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    def __call__(
        self,
        input_ids: chex.Array,
        vision_mask: Optional[chex.Array] = None,
        attention_mask: Optional[chex.Array] = None,
        position_ids: Optional[chex.Array] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
        add_params_field: bool = False,
        **kwargs
    ):
        """
        Run the wrapped module on a batch of token ids.

        :param input_ids: chex.Array: 2-D array of token ids, unpacked as (batch_size, sequence_length)
        :param vision_mask: Optional[chex.Array]: Mask passed to the module as float32; when omitted an
            all-zeros mask of the same shape as `input_ids` is used
            (presumably meaning "no vision tokens" - confirm against the module)
        :param attention_mask: Optional[chex.Array]: Defaults to all ones of shape (batch_size, sequence_length)
        :param position_ids: Optional[chex.Array]: Defaults to `arange(sequence_length)` broadcast over the
            batch; must be given explicitly when a cache is supplied
        :param params: dict: Parameters to use instead of `self.params`
        :param past_key_values: dict: Cache collection (e.g. from `init_cache`); when non-empty the module
            is applied with a mutable "cache" collection and the updated cache is returned
        :param dropout_rng: jax.random.PRNGKey: Key for the "dropout" rng stream, if any
        :param train: bool: Its negation is passed positionally to the module
            (presumably as `deterministic` - confirm against the module signature)
        :param output_attentions: Optional[bool]: Falls back to `config.output_attentions`
        :param output_hidden_states: Optional[bool]: Falls back to `config.output_hidden_states`
        :param return_dict: Optional[bool]: Falls back to `config.return_dict`
        :param extra_embedding: Accepted for interface compatibility; not used by this method
        :param add_params_field: bool: Accepted for interface compatibility; not used by this method
        :param kwargs: Accepted for interface compatibility; not used by this method
        :return: The module outputs, with the updated cache inserted when a cache was supplied
        :raises ValueError: If a cache is supplied without `position_ids`
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        # One flag decides both whether the cache is fed to `apply` and whether the result
        # is unpacked as (outputs, mutated_variables); the two must never disagree.
        use_cache = bool(past_key_values)

        if position_ids is None:
            if use_cache:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if vision_mask is None:
            # `jnp.array(None, dtype="f4")` is NaN or a TypeError depending on the JAX version,
            # so build an explicit all-zeros mask instead.
            vision_mask = jnp.zeros((batch_size, sequence_length), dtype="f4")

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(vision_mask, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,  # presumably the module's `init_cache` flag - confirm against its signature
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if use_cache:
            # With a mutable collection, `apply` returns (outputs, mutated_variables).
            outputs, mutated = outputs
            updated_cache = unfreeze(mutated["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs
|