Skip to content

modules.opt.modelling_opt_flax

Flax OPT model.

FlaxOPTLearnedPositionalEmbedding

Bases: Embed

Source code in src/python/easydel/modules/opt/modelling_opt_flax.py
326
327
328
329
330
331
332
333
334
335
336
337
class FlaxOPTLearnedPositionalEmbedding(nn.Embed):
    """Learned positional embedding that shifts every lookup index by a fixed offset.

    The embedding table is allocated with ``offset`` extra rows
    (``num_embeddings + offset`` in total), and ``__call__`` adds ``offset`` to
    the incoming positions before delegating the lookup to the parent class.
    """

    def setup(self):
        # The offset must be set first: the table shape below depends on it.
        # NOTE(review): presumably mirrors the offset used by the reference OPT
        # implementation -- confirm before changing.
        self.offset = 2
        table_shape = (self.num_embeddings + self.offset, self.features)
        self.embedding = self.param(
            "embedding",
            self.embedding_init,
            table_shape,
            self.param_dtype,
        )

    def __call__(self, positions):
        """Return the embeddings for ``positions`` shifted by ``self.offset``.

        Args:
            positions: Position indices; the original docstring states the
                expected layout is [bsz x seqlen].

        Returns:
            Whatever the parent ``__call__`` returns for the shifted indices.
        """
        shifted_positions = positions + self.offset
        return super().__call__(shifted_positions)

__call__(positions)

positions is expected to be [bsz x seqlen]; the offset is added to it before the embedding lookup.

Source code in src/python/easydel/modules/opt/modelling_opt_flax.py
334
335
336
337
def __call__(self, positions):
    """Return the embeddings for ``positions`` shifted by ``self.offset``.

    Args:
        positions: Position indices; the original docstring states the
            expected layout is [bsz x seqlen].

    Returns:
        Whatever the parent ``__call__`` returns for the shifted indices.
    """
    shifted_positions = positions + self.offset
    return super().__call__(shifted_positions)