Skip to content

xrapture.implicit_array

ArrayValue

Helper class that provides a standard way to create an ABC using inheritance.

Source code in src/fjformer/xrapture/implicit_array.py
46
47
48
49
50
51
52
53
class ArrayValue(metaclass=ABCMeta):
    """Abstract base class (built with ``ABCMeta``) for array-like values.

    The class defines no methods of its own; it only declares an empty
    ``__slots__`` and three class-level defaults.
    """
    # Empty __slots__: this class itself contributes no per-instance attributes.
    __slots__ = ()
    # Class-level default; presumably replaced by a concrete shape in subclasses -- TODO confirm.
    shape = None
    # NOTE(review): purpose is not evident from this block -- confirm against callers.
    e_num_val = None
    # NOTE(review): name suggests a pjit-registration flag; defaults to False -- confirm usage.
    is_registered_by_pJit = False

Complement

Relative complement, i.e. Complement[A, B] = A - B

Source code in src/fjformer/xrapture/implicit_array.py
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
@parametric
class Complement(metaclass=_ComplementMeta):
    """
    Parametric type describing a relative complement.

    ``Complement[A, B]`` stands for ``A - B``.
    """

    @classmethod
    @dispatch
    def __init_type_parameter__(
            cls,
            a: Optional[Any],
            b: Optional[Any],
    ):
        # The two type parameters are handed back as a pair.
        pair = (a, b)
        return pair

    @classmethod
    @dispatch
    def __le_type_parameter__(
            cls,
            left: Tuple[Optional[Any], Optional[Any]],
            right: Tuple[Optional[Any], Optional[Any]],
    ):
        (a_sub, b_sub), (a_sup, b_sup) = left, right

        # The first parameter must be a subclass of its counterpart on the
        # right; the second parameter is compared in the opposite direction.
        if not issubclass(a_sub, a_sup):
            return False
        return issubclass(b_sup, b_sub)

ImplicitArray dataclass

Bases: _ImplicitArrayBase

Abstract class for representing an abstract array of a given shape/dtype without actually instantiating it. Subclasses must implement the materialize method, which defines the relationship between the implicit array and the value it represents. Subclasses are valid arguments to functions decorated with qax.use_implicit_args.

All subclasses are automatically registered as pytrees using jax.tree_util.register_pytree_with_keys_class. Any dataclass attributes added will be included as children, unless they are decorated with qax.aux_field in which case they are passed as auxiliary data during flattening.

The represented shape and dtype may be defined in any of the following ways:
- Explicitly passing shape/dtype keyword arguments at initialization
- Overriding the default_shape/default_dtype class variables
- Overriding the compute_shape/compute_dtype methods, which are called during __post_init__
- Overriding __post_init__ and manually setting shape/dtype before calling super().__post_init__
- None of the above, in which case a shape/dtype will be inferred by running jax.eval_shape() on the subclass's materialize method.

Source code in src/fjformer/xrapture/implicit_array.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
@dataclass
class ImplicitArray(_ImplicitArrayBase):
    """
    Abstract class for representing an abstract array of a given shape/dtype without actually instantiating it.
    Subclasses must implement the materialize method, which defines the relationship between the implicit array
    and the value it represents. Subclasses are valid arguments to functions decorated with qax.use_implicit_args.

    All subclasses are automatically registered as pytrees using jax.tree_util.register_pytree_with_keys_class.
    Any dataclass attributes added will be included as children, unless they are decorated with qax.aux_field
    in which case they are passed as auxiliary data during flattening.

    The represented shape and dtype may be defined in any of the following ways:
        - Explicitly passing shape/dtype keyword arguments at initialization
        - Overriding the default_shape/default_dtype class variables
        - Overriding the compute_shape/compute_dtype methods, which are called during __post_init__
        - Overriding __post_init__ and manually setting shape/dtype before calling super().__post_init__
        - None of the above, in which case a shape/dtype will be inferred by running jax.eval_shape()
          on the subclass's materialize method.
    """

    # Descriptor-backed attributes; reading one that has not been set yet may
    # raise UninitializedAval (see the try/except blocks in __post_init__).
    shape = _AvalDescriptor()
    dtype = _AvalDescriptor()

    def __post_init__(self):
        """
        Resolve ``shape`` and ``dtype`` after dataclass initialization.

        For each of the two attributes the explicitly set value is used if
        present, otherwise compute_shape()/compute_dtype() is consulted, and
        finally the abstract value of the materialization is used as a fallback.
        A mismatch between a declared value and the materialization only emits
        a warning.

        Raises:
            UninitializedAval: if shape or dtype cannot be determined.
        """
        try:
            aval = _get_materialization_aval(self)
        except UninitializedAval:
            # Materialization depends on currently uninitialized shape/dtype
            aval = None

        shape = None
        try:
            shape = self.shape
        except UninitializedAval:
            shape = self.shape = self.compute_shape()

        if aval is not None:
            if shape is None:
                self.shape = aval.shape
            elif shape != aval.shape:
                warnings.warn(f"ImplicitArray shape {shape} does not match materialization shape {aval.shape}")
        elif shape is None:
            raise UninitializedAval("shape")

        dtype = None
        try:
            dtype = self.dtype
        except UninitializedAval:
            dtype = self.dtype = self.compute_dtype()

        if dtype is None and aval is None:
            # We have a shape but not a dtype, try once again to infer the dtype
            aval = _get_materialization_aval(self)

        if aval is not None:
            if dtype is None:
                self.dtype = aval.dtype
            elif dtype != aval.dtype:
                warnings.warn(f"ImplicitArray dtype {dtype} does not match materialization dtype {aval.dtype}")
        elif dtype is None:
            raise UninitializedAval("dtype")

    def compute_shape(self):
        """
        Override this method if the subclass instance's shape should be computed based on its other properties.
        Returns: shape
        """
        return self.default_shape

    def compute_dtype(self):
        """
        Override this method if the subclass instance's dtype should be computed based on its other properties.
        Returns: dtype
        """
        return self.default_dtype

    @property
    def aval(self):
        """Abstract value (``core.ShapedArray``) built from this instance's shape and dtype."""
        return core.ShapedArray(self.shape, self.dtype)

    @classmethod
    def default_handler(cls, primitive, *args, params=None):
        """
        Fallback primitive handler: forwards to ``materialize_handler``.

        Arguments:
            primitive: The primitive being applied
            *args: Positional operands of the primitive
            params: Primitive parameters; ``None`` is treated as an empty dict
        Returns:
            Whatever ``materialize_handler`` returns
        """
        if params is None:
            params = {}
        return materialize_handler(primitive, *args, params=params)

    @abstractmethod
    def materialize(self):
        """Return the value this implicit array represents. Must be implemented by subclasses."""
        pass

    def tree_flatten_with_keys(self):
        """
        Flatten this instance for pytree registration.

        Returns:
            A pair ``(children, aux_data)`` where ``children`` is a list of
            ``(name, value)`` tuples and ``aux_data`` is a list of the values
            of the fields reported as auxiliary by ``_get_names_and_aux``.
        """
        children = []
        aux_data = []
        for name, is_aux in _get_names_and_aux(self):
            try:
                value = getattr(self, name)
            except UninitializedAval:
                # An unset shape/dtype is only tolerated (recorded as None)
                # while _aval_discovery reports a truthy value.
                if not _aval_discovery.get():
                    raise
                value = None
            if is_aux:
                aux_data.append(value)
            else:
                children.append((name, value))

        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Rebuild an instance from flattened data.

        The instance is created with ``cls.__new__`` (so ``__init__`` and
        ``__post_init__`` are not run) and every field is assigned in the order
        produced by ``_get_names_and_aux``.
        """
        child_it = iter(children)
        aux_it = iter(aux_data)
        obj = cls.__new__(cls)
        for name, is_aux in _get_names_and_aux(cls):
            value = next(aux_it if is_aux else child_it)
            setattr(obj, name, value)

        return obj

    def handle_primitive(self, primitive, *args, params):
        """
        Apply ``primitive`` to ``args`` using the handler returned by ``get_primitive_handler``.

        Arguments:
            primitive: The primitive being applied
            *args: Operands of the primitive
            params: Primitive parameters (keyword-only)
        Returns:
            The handler's result, unflattened to its original tree structure
        """
        handler = lu.wrap_init(partial(get_primitive_handler(primitive), primitive))
        use_params = params

        if len(args) == 2 and self.commute_ops:
            args, use_params = _maybe_swap_args(primitive.name, args, use_params)

        # Flatten the (possibly swapped) operands together with the params that
        # _maybe_swap_args returned for them, so both stay in sync.
        flat_args, in_tree = flatten_one_implicit_layer((args, use_params))
        flat_handler, out_tree = flatten_fun(handler, in_tree)

        result = use_implicit_args(flat_handler.call_wrapped)(*flat_args)
        return jax.tree_util.tree_unflatten(out_tree(), result)

    def __init_subclass__(cls, commute_ops=True, **kwargs):
        """
        Register every subclass as a pytree and as a type with a known abstract value.

        Raises:
            TypeError: if ``is_dataclass(cls)`` is false.
        """
        super().__init_subclass__(**kwargs)

        if not is_dataclass(cls):
            raise TypeError(f"{cls.__name__} must be a dataclass")
        core.pytype_aval_mappings[cls] = lambda x: x.aval
        register_pytree_with_keys_class(cls)
        return cls

compute_dtype()

Override this method if the subclass instance's dtype should be computed based on its other properties. Returns: dtype

Source code in src/fjformer/xrapture/implicit_array.py
335
336
337
338
339
340
def compute_dtype(self):
    """
    Hook for subclasses whose dtype depends on their other properties.

    The base implementation simply falls back to ``default_dtype``.
    Returns: dtype
    """
    fallback_dtype = self.default_dtype
    return fallback_dtype

compute_shape()

Override this method if the subclass instance's shape should be computed based on its other properties. Returns: shape

Source code in src/fjformer/xrapture/implicit_array.py
328
329
330
331
332
333
def compute_shape(self):
    """
    Hook for subclasses whose shape depends on their other properties.

    The base implementation simply falls back to ``default_shape``.
    Returns: shape
    """
    fallback_shape = self.default_shape
    return fallback_shape

apply_updates(params, updates)

Like optax.apply_updates, but updates can be SymbolicConstant instances

Source code in src/fjformer/xrapture/implicit_array.py
979
980
981
982
983
984
985
986
987
988
def apply_updates(params: optax.Params, updates: optax.Updates) -> optax.Params:
    """
    Variant of optax.apply_updates in which leaves of ``updates`` may be
    SymbolicConstant instances.
    """

    def _is_symbolic(node):
        return isinstance(node, SymbolicConstant)

    # SymbolicConstant nodes are kept whole instead of being flattened further.
    flat_updates, update_treedef = tree_util.tree_flatten(updates, is_leaf=_is_symbolic)
    # Flatten params only as far as the structure of the updates goes.
    params_prefix = update_treedef.flatten_up_to(params)

    apply_fn = use_implicit_args(optax.apply_updates)
    new_flat = apply_fn(params_prefix, flat_updates)
    return update_treedef.unflatten(new_flat)

freeze_subtrees(optimizer, label_fn, use_scalar_zeros=False)

Utility which wraps an optimizer such that subtrees specified by label_fn will receive zeros as updates. Subtrees to be frozen should be labeled with "freeze" and all other subtrees should be labeled with "train"

Source code in src/fjformer/xrapture/implicit_array.py
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
def freeze_subtrees(optimizer: optax.GradientTransformation, label_fn, use_scalar_zeros=False):
    """
    Utility which wraps an optimizer such that subtrees specified by
    label_fn will receive zeros as updates.
    Subtrees to be frozen should be labeled with "freeze"
    and all other subtrees should be labeled with "train"

    Arguments:
        optimizer: The transformation applied to subtrees labeled "train"
        label_fn: Labeling argument forwarded to optax.multi_transform
        use_scalar_zeros: If True, frozen updates (and replaced float0 gradients)
            are scalar zeros of shape () instead of full-size zero arrays
    Returns:
        An optax.GradientTransformation whose update function requires ``params``
    """
    multi_transformed_optimizer = optax.multi_transform(
        {
            'freeze': set_to_zero_scalar() if use_scalar_zeros else optax.set_to_zero(),
            'train': optimizer
        },
        label_fn
    )

    def new_update(grads, opt_state, params):
        def map_float0(param, grad):
            # Gradients of dtype float0 are swapped for zeros in the
            # parameter's dtype before the wrapped optimizer sees them.
            if grad.dtype == float0:
                return jnp.zeros((), param.dtype) if use_scalar_zeros else jnp.zeros_like(param)
            return grad

        # jax.tree_util.tree_map is the supported spelling; the top-level
        # jax.tree_map alias is deprecated and absent from recent JAX releases.
        fixed_grads = jax.tree_util.tree_map(map_float0, params, grads)
        return multi_transformed_optimizer.update(fixed_grads, opt_state, params)

    return optax.GradientTransformation(
        multi_transformed_optimizer.init,
        new_update
    )

get_common_prefix_transforms(trees)

Given an iterable of pytrees which have the same structure after all ImplicitArray instances are materialized, return a list of callables which will transform each tree into the largest common structure obtainable via materialization of ImplicitArrays.

Source code in src/fjformer/xrapture/implicit_array.py
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
def get_common_prefix_transforms(trees):
    """
    Compute one transform per tree that brings all trees to a shared structure.

    ``trees`` holds pytrees that share a structure once every ImplicitArray
    has been materialized. The returned list contains one callable per tree;
    applying it turns that tree into the largest common structure that can be
    reached by materializing ImplicitArrays.

    Raises:
        ValueError: if the trees differ in structure or in leaf shape/dtype
            after materialization.
    """
    if len(trees) <= 1:
        # Zero or one tree: nothing to reconcile, identity transforms suffice.
        return [lambda x: x for _ in trees]

    flattened = [tree_flatten_with_implicit(tree) for tree in trees]
    all_leaves, structures = zip(*flattened)
    post_materialization_avals = [core.get_aval(leaf) for leaf in all_leaves[0]]

    # Every tree after the first must agree with tree 0 once materialized.
    for i in range(1, len(structures)):
        if structures[i] != structures[0]:
            raise ValueError('Trees do not have the same structure after materialization')

        for leaf, expected_aval in zip(all_leaves[i], post_materialization_avals):
            aval = core.get_aval(leaf)
            shapes_match = aval.shape == expected_aval.shape
            if shapes_match and aval.dtype == expected_aval.dtype:
                continue
            raise ValueError(
                f'Trees do not have the same avals after materialization. Tree 0: {expected_aval}, Tree {i}: {aval}'
            )

    # Work list of (path, nodes) entries:
    #   path  - tuple of child indices, one per flattening step taken so far
    #   nodes - the node found at that path in each of the trees
    top_level_leaves = [tree_leaves_with_implicit(tree) for tree in trees]
    pending = [((index,), group) for index, group in enumerate(zip(*top_level_leaves))]

    to_materialize = set()
    while pending:
        prefix, group = pending.pop()
        has_implicit = any(isinstance(member, ImplicitArray) for member in group)
        if not has_implicit:
            continue

        peeled = [flatten_one_implicit_layer(member) for member in group]
        layer_leaves, layer_structures = zip(*peeled)
        if len(set(layer_structures)) > 1:
            # The nodes disagree structurally one layer down: materialize here.
            to_materialize.add(prefix)
            continue

        diverged = False
        for column in zip(*layer_leaves):
            reference = core.get_aval(column[0])
            ref_shape, ref_dtype = reference.shape, reference.dtype
            for other in column[1:]:
                other_aval = core.get_aval(other)
                if other_aval.shape == ref_shape and other_aval.dtype == ref_dtype:
                    continue
                to_materialize.add(prefix)
                diverged = True
            if diverged:
                break

        if diverged:
            continue

        # The layer is consistent across trees, so descend into its children.
        pending.extend(
            (prefix + (index,), column) for index, column in enumerate(zip(*layer_leaves))
        )

    return [_get_pruning_transform(tree, to_materialize) for tree in trees]

materialize_nested(implicit_arr, full=False)

Materialize an ImplicitArray instance, handling the case where implicit_arr.materialize() involves further ImplicitArray instances. Arguments: implicit_arr: An ImplicitArray instance full: If True, repeatedly materialize until the result is a concrete array Returns: The materialized array

Source code in src/fjformer/xrapture/implicit_array.py
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
def materialize_nested(implicit_arr, full=False):
    """
    Materialize an ImplicitArray whose materialize() may itself involve
    further ImplicitArray instances.
    Arguments:
        implicit_arr: An ImplicitArray instance
        full: If True, keep materializing until the result is no longer an ImplicitArray
    Returns:
        The materialized array
    """
    current = implicit_arr
    keep_going = True
    while keep_going and isinstance(current, ImplicitArray):
        materialize_fn = lu.wrap_init(type(current).materialize)
        leaves, in_tree = flatten_one_implicit_layer((current,))
        flat_fn, out_tree = flatten_fun_nokwargs(materialize_fn, in_tree)
        flat_out = use_implicit_args(flat_fn.call_wrapped)(*leaves)
        current = jax.tree_util.tree_unflatten(out_tree(), flat_out)

        # Without ``full`` only a single layer is peeled off.
        keep_going = full

    return current

set_to_zero_scalar()

Returns a gradient transformation that sets all gradients to 0 in order to make downstream constant folding cheaper.

Source code in src/fjformer/xrapture/implicit_array.py
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
def set_to_zero_scalar() -> optax.GradientTransformation:
    """
    Returns a gradient transformation that sets all gradients to 0 in order to
    make downstream constant folding cheaper.

    Every update leaf is replaced by a zero of shape () carrying the leaf's dtype.
    The transformation keeps no state (``optax.EmptyState``).
    """

    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        # jax.tree_util.tree_map is the supported spelling; the top-level
        # jax.tree_map alias is deprecated and absent from recent JAX releases.
        return jax.tree_util.tree_map(lambda x: jnp.zeros((), x.dtype), updates), state

    return optax.GradientTransformation(init_fn, update_fn)

use_implicit_args(f)

Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly including further ImplicitArray instances as children. Any number of arguments (including 0) may be ImplicitArrays.

Source code in src/fjformer/xrapture/implicit_array.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def use_implicit_args(f):
    """
    Decorator that lets ``f`` be called with ImplicitArray arguments.

    Arguments may subclass ImplicitArray and may themselves contain further
    ImplicitArray instances as children. Any number of arguments
    (including 0) may be ImplicitArrays.
    """

    @wraps(f)
    def implicit_f(*args, **kwargs):
        leaves, in_tree = tree_flatten_with_implicit((args, kwargs))
        flat_fn, out_tree_thunk = flatten_fun(lu.wrap_init(f), in_tree)
        flat_outputs = _with_implicit_flat(flat_fn).call_wrapped(*leaves)
        # The output tree structure is only available after the call has run.
        out_tree = out_tree_thunk()
        return out_tree.unflatten(flat_outputs)

    return implicit_f

vmap_all_but_one(f, axis, out_ndim=0)

Repeatedly calls vmap to map over all axes except for axis. All args will be mapped on the same dimensions.

Source code in src/fjformer/xrapture/implicit_array.py
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
def vmap_all_but_one(f, axis, out_ndim=0):
    """
    Repeatedly calls vmap to map over all axes except for `axis`.
    All args will be mapped on the same dimensions.

    Arguments:
        f: Function to wrap
        axis: The axis that is NOT mapped over. Negative values count from the
            last axis, as in regular indexing.
        out_ndim: Used as the vmap ``out_axes`` for the axes that come after
            `axis`; the axes before `axis` always use 0.
    Returns:
        A wrapped function. The number of dimensions is taken from its first
        positional argument.
    Raises:
        ValueError: (when the wrapped function is called) if `axis` is out of
            bounds for the first argument's number of dimensions.
    """

    @wraps(f)
    def inner(*args):
        n_dim = args[0].ndim
        if not -n_dim <= axis < n_dim:
            raise ValueError(f'Axis {axis} is out of bounds for array of dimension {n_dim}')
        # Normalize a negative axis so it can match a loop index below;
        # otherwise it would never be skipped and every axis would be mapped.
        kept_axis = axis % n_dim
        fn = f
        vmap_dim = 1
        out_dim = out_ndim
        for i in reversed(range(n_dim)):
            if i == kept_axis:
                # Axes before the kept one are mapped at position 0.
                vmap_dim = 0
                out_dim = 0
            else:
                fn = jax.vmap(fn, vmap_dim, out_dim)
        return fn(*args)

    return inner