Skip to content

serve.utils

Seafoam

Bases: Base

Source code in src/python/easydel/serve/utils.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class Seafoam(Base):
    """Theme built on ``Base`` (presumably ``gradio.themes.Base``) with gradient
    backgrounds, Quicksand / IBM Plex Mono fonts and emerald/blue/gray hues."""

    def __init__(
            self,
            *,
            primary_hue: Union[colors.Color, str] = colors.emerald,
            secondary_hue: Union[colors.Color, str] = colors.blue,
            neutral_hue: Union[colors.Color, str] = colors.gray,
            spacing_size: Union[sizes.Size, str] = sizes.spacing_md,
            radius_size: Union[sizes.Size, str] = sizes.radius_md,
            text_size: Union[sizes.Size, str] = sizes.text_lg,
            # ``tuple`` is included because the defaults are tuples of fonts.
            font: Union[fonts.Font, str, tuple] = (
                    fonts.GoogleFont("Quicksand"),
                    "ui-sans-serif",
                    "sans-serif",
            ),
            font_mono: Union[fonts.Font, str, tuple] = (
                    fonts.GoogleFont("IBM Plex Mono"),
                    "ui-monospace",
                    "monospace",
            ),
    ):
        """Initialize the Seafoam theme.

        Forwards the hue, size and font settings to ``Base.__init__`` and then
        overrides a fixed set of theme variables via ``set()``:
        - gradient body and primary-button backgrounds
        - primary-button text color
        - slider colors
        - shadows
        - border colors

        All parameters are keyword-only (the bare ``*`` in the signature is not a parameter).

        :param primary_hue: Primary color of the theme.
        :param secondary_hue: Secondary color of the theme.
        :param neutral_hue: Neutral color of the theme.
        :param spacing_size: Spacing size of the theme.
        :param radius_size: Corner radius of buttons and other elements.
        :param text_size: Size of the text in the app.
        :param font: Font, font name, or tuple of fonts used for regular text; the
            default lists a Google font followed by generic font families.
        :param font_mono: Font, font name, or tuple of fonts used for monospace text.
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Background gradient used in both light and dark mode.
        body_gradient = "linear-gradient(90deg, *secondary_800, *neutral_900)"
        # Gradient shared by the dark primary-button background and every border color.
        accent_gradient = "linear-gradient(90deg, *primary_600, *secondary_800)"
        # NOTE(review): CSS ``border-color`` accepts colors, not gradients, so the
        # ``*border_color*`` values below are likely ignored by the browser
        # (``block_border_width`` is also 0px) — confirm the intended look.
        super().set(
            body_background_fill=body_gradient,
            body_background_fill_dark=body_gradient,
            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
            button_primary_text_color="white",
            button_primary_background_fill_dark=accent_gradient,
            slider_color="*secondary_300",
            slider_color_dark="*secondary_400",
            block_title_text_weight="600",
            block_border_width="0px",
            block_shadow="*shadow_drop_lg",
            button_shadow="*shadow_drop_lg",
            button_large_padding="4px",
            border_color_primary=accent_gradient,
            border_color_primary_dark=accent_gradient,
            table_border_color=accent_gradient,
            table_border_color_dark=accent_gradient,
            button_primary_border_color=accent_gradient,
            button_primary_border_color_dark=accent_gradient,
            panel_border_color=accent_gradient,
            panel_border_color_dark=accent_gradient,
            block_border_color=accent_gradient,
            block_border_color_dark=accent_gradient,
        )

__init__(*, primary_hue=colors.emerald, secondary_hue=colors.blue, neutral_hue=colors.gray, spacing_size=sizes.spacing_md, radius_size=sizes.radius_md, text_size=sizes.text_lg, font=(fonts.GoogleFont('Quicksand'), 'ui-sans-serif', 'sans-serif'), font_mono=(fonts.GoogleFont('IBM Plex Mono'), 'ui-monospace', 'monospace'))

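The `__init__` method initializes the Seafoam theme. It forwards the hue, size and font settings to `Base.__init__`, then uses `set()` to override a fixed set of theme variables: gradient backgrounds, primary-button colors, slider colors, shadows and border colors.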

Parameters:

Name Type Description Default
self

The theme instance being initialized.

required
*

Not a parameter: marks every following parameter as keyword-only.

—
primary_hue Union[Color, str]

Primary color of the theme.

emerald
secondary_hue Union[Color, str]

Secondary color of the theme.

blue
neutral_hue Union[Color, str]

Neutral color of the theme.

gray
spacing_size Union[Size, str]

Spacing size of the theme.

spacing_md
radius_size Union[Size, str]

Corner radius of buttons and other elements.

radius_md
text_size Union[Size, str]

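Size of the text in the app.

text_lg
font Font, str, or tuple of these

Font, font name, or tuple of fonts used for regular text.

(GoogleFont('Quicksand'), 'ui-sans-serif', 'sans-serif')
font_mono Font, str, or tuple of these

Font, font name, or tuple of fonts used for monospace text.

(GoogleFont('IBM Plex Mono'), 'ui-monospace', 'monospace')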

Returns:

Type Description

None. `__init__` only initializes the instance; calling `Seafoam(...)` is what yields the theme object.

Source code in src/python/easydel/serve/utils.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
        self,
        *,
        primary_hue: Union[colors.Color, str] = colors.emerald,
        secondary_hue: Union[colors.Color, str] = colors.blue,
        neutral_hue: Union[colors.Color, str] = colors.gray,
        spacing_size: Union[sizes.Size, str] = sizes.spacing_md,
        radius_size: Union[sizes.Size, str] = sizes.radius_md,
        text_size: Union[sizes.Size, str] = sizes.text_lg,
        # ``tuple`` is included because the defaults are tuples of fonts.
        font: Union[fonts.Font, str, tuple] = (
                fonts.GoogleFont("Quicksand"),
                "ui-sans-serif",
                "sans-serif",
        ),
        font_mono: Union[fonts.Font, str, tuple] = (
                fonts.GoogleFont("IBM Plex Mono"),
                "ui-monospace",
                "monospace",
        ),
):
    """Initialize the Seafoam theme.

    Forwards the hue, size and font settings to ``Base.__init__`` and then
    overrides a fixed set of theme variables via ``set()``:
    - gradient body and primary-button backgrounds
    - primary-button text color
    - slider colors
    - shadows
    - border colors

    All parameters are keyword-only (the bare ``*`` in the signature is not a parameter).

    :param primary_hue: Primary color of the theme.
    :param secondary_hue: Secondary color of the theme.
    :param neutral_hue: Neutral color of the theme.
    :param spacing_size: Spacing size of the theme.
    :param radius_size: Corner radius of buttons and other elements.
    :param text_size: Size of the text in the app.
    :param font: Font, font name, or tuple of fonts used for regular text; the
        default lists a Google font followed by generic font families.
    :param font_mono: Font, font name, or tuple of fonts used for monospace text.
    """
    super().__init__(
        primary_hue=primary_hue,
        secondary_hue=secondary_hue,
        neutral_hue=neutral_hue,
        spacing_size=spacing_size,
        radius_size=radius_size,
        text_size=text_size,
        font=font,
        font_mono=font_mono,
    )
    # Background gradient used in both light and dark mode.
    body_gradient = "linear-gradient(90deg, *secondary_800, *neutral_900)"
    # Gradient shared by the dark primary-button background and every border color.
    accent_gradient = "linear-gradient(90deg, *primary_600, *secondary_800)"
    # NOTE(review): CSS ``border-color`` accepts colors, not gradients, so the
    # ``*border_color*`` values below are likely ignored by the browser
    # (``block_border_width`` is also 0px) — confirm the intended look.
    super().set(
        body_background_fill=body_gradient,
        body_background_fill_dark=body_gradient,
        button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
        button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
        button_primary_text_color="white",
        button_primary_background_fill_dark=accent_gradient,
        slider_color="*secondary_300",
        slider_color_dark="*secondary_400",
        block_title_text_weight="600",
        block_border_width="0px",
        block_shadow="*shadow_drop_lg",
        button_shadow="*shadow_drop_lg",
        button_large_padding="4px",
        border_color_primary=accent_gradient,
        border_color_primary_dark=accent_gradient,
        table_border_color=accent_gradient,
        table_border_color_dark=accent_gradient,
        button_primary_border_color=accent_gradient,
        button_primary_border_color_dark=accent_gradient,
        panel_border_color=accent_gradient,
        panel_border_color_dark=accent_gradient,
        block_border_color=accent_gradient,
        block_border_color_dark=accent_gradient,
    )

create_generate_function(model, generation_config, params, generation_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp'), output_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp'), logits_processor=None, return_prediction_only=True)

Create a sharded function for text generation using a Flax model.

Parameters:
- model (EasyDeLFlaxPretrainedModel): The Flax model used for text generation.
- generation_config (GenerationConfig): Configuration for text generation.
- params (dict or jax.tree_util.PyTreeDef): Model parameters. Their per-leaf sharding specs determine how the `parameters` argument of the returned function is sharded.
- generation_partition_spec (PartitionSpec): Sharding specification for `input_ids` and `attention_mask`. Defaults to PartitionSpec(("dp", "fsdp"), "sp").
- output_partition_spec (PartitionSpec): Sharding specification for the output sequences. Defaults to PartitionSpec(("dp", "fsdp"), "sp").
- logits_processor (LogitsProcessor, optional): Processor for model logits. Defaults to None.
- return_prediction_only (bool): If True, return only the newly generated tokens (the prompt is sliced off). Otherwise return the full sequences. Defaults to True.

Returns:

Type Description
Callable[[Union[dict, PyTreeDef], Array, Array], Array]

Sharded (pjit-compiled) function, called as `fn(params, input_ids, attention_mask)`, that returns the generated token sequences.

Source code in src/python/easydel/serve/utils.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def create_generate_function(
        model: EasyDeLFlaxPretrainedModel,
        generation_config: GenerationConfig,
        params: Union[dict, jax.tree_util.PyTreeDef],
        generation_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp"),
        output_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp"),
        logits_processor: Optional[LogitsProcessor] = None,
        return_prediction_only: bool = True
) -> Callable[[Union[dict, jax.tree_util.PyTreeDef], chex.Array, chex.Array], chex.Array]:
    """Build a pjit-compiled text-generation function for ``model``.

    The returned callable is invoked as ``fn(parameters, input_ids, attention_mask)``. It:
    1. constrains both token arrays to ``generation_partition_spec``;
    2. runs ``model.generate`` with ``generation_config`` and ``logits_processor``;
    3. returns the generated sequences, sharded according to ``output_partition_spec``.

    :param model: Model whose ``generate`` method produces the sequences.
    :param generation_config: Generation settings forwarded to ``model.generate``.
    :param params: Parameter tree. It is used here only to derive the input
        sharding of the ``parameters`` argument, by applying ``get_partitions`` to
        every leaf.
    :param generation_partition_spec: Sharding of ``input_ids`` and ``attention_mask``.
    :param output_partition_spec: Sharding of the returned sequences.
    :param logits_processor: Optional logits processor forwarded to ``model.generate``.
    :param return_prediction_only: If True, slice off the first
        ``input_ids.shape[1]`` positions (the prompt) and return only the newly
        generated tokens. Otherwise return the full sequences.
    :return: The sharded generation function.
    """
    # Parameter shardings are fixed once, at build time, from the tree passed in here.
    param_shardings = jax.tree_util.tree_map(get_partitions, params)
    input_shardings = (param_shardings, generation_partition_spec, generation_partition_spec)

    def _sharded_generate(
            parameters: Union[dict, jax.tree_util.PyTreeDef],
            input_ids: chex.Array,
            attention_mask: chex.Array
    ) -> chex.Array:
        """Run ``model.generate`` on sharded inputs and return the token sequences.

        :param parameters: Model parameters.
        :param input_ids: Prompt token ids.
        :param attention_mask: Attention mask matching ``input_ids``.
        :return: Generated sequences (prompt removed when ``return_prediction_only``).
        """
        input_ids = with_sharding_constraint(input_ids, generation_partition_spec)
        attention_mask = with_sharding_constraint(attention_mask, generation_partition_spec)
        sequences = model.generate(
            input_ids,
            attention_mask=attention_mask,
            params=parameters,
            generation_config=generation_config,
            logits_processor=logits_processor
        ).sequences
        if not return_prediction_only:
            return sequences
        prompt_length = input_ids.shape[1]
        return sequences[:, prompt_length:]

    return pjit(
        _sharded_generate,
        in_shardings=input_shardings,
        out_shardings=output_partition_spec
    )

get_partitions(tree)

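Return the `PartitionSpec` of a parameter leaf, falling back to `PartitionSpec(None)` when its sharding exposes no spec. For a `LinearBitKernel` leaf, return a `LinearBitKernel` holding the specs of its kernel and scale.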

Source code in src/python/easydel/serve/utils.py
120
121
122
123
124
125
126
127
128
129
130
def get_partitions(tree):
    """Return the sharding spec(s) for a single parameter leaf.

    Intended for use with ``jax.tree_util.tree_map`` when building ``in_shardings``.

    :param tree: A parameter leaf: either an array-like object or a
        ``fjformer.linen.LinearBitKernel`` holding a ``kernel`` and a ``scale``.
    :return: The leaf's ``sharding.spec``. Falls back to ``PartitionSpec(None)``
        when the leaf has no ``sharding`` attribute (e.g. a NumPy array or a Python
        scalar) or its sharding exposes no ``spec``. For a ``LinearBitKernel``,
        returns a ``LinearBitKernel`` whose ``kernel`` and ``scale`` fields hold the
        corresponding specs.
    """

    def spec_of(leaf):
        # Guard the ``sharding`` lookup too: accessing ``leaf.sharding`` directly
        # raises AttributeError for leaves that are not sharded jax arrays, which
        # previously bypassed the intended ``PartitionSpec(None)`` fallback.
        return getattr(getattr(leaf, "sharding", None), "spec", PartitionSpec(None))

    if isinstance(tree, fjformer.linen.LinearBitKernel):
        return fjformer.linen.LinearBitKernel(
            kernel=spec_of(tree.kernel),  # type:ignore
            scale=spec_of(tree.scale),  # type:ignore
        )
    return spec_of(tree)