xrapture.tracer

ImplicitArrayTrace

Bases: Trace

Source code in src/fjformer/xrapture/tracer.py
class ImplicitArrayTrace(Trace):
    pure = lift = lambda self, val: ImplicitArrayTracer(self, val)

    def process_primitive(self, primitive, tracers, params):
        """
        Prepare the inputs of a primitive applied to ImplicitArrayTracers.

        Collects the wrapped value of each input tracer and counts how many of
        them are ImplicitArray instances. One or two implicit inputs are
        supported (enforced with an assert). With two, a warning is emitted
        and the second input, vals[1], is materialized into a regular array.

        :param primitive: The JAX primitive being applied
        :param tracers: The input tracers; each holds its wrapped value in .value
        :param params: The primitive's keyword parameters (not used in this listing)
        :return: None; the method ends after the materialization step
        """
        vals = [t.value for t in tracers]
        n_implicit = sum(isinstance(v, ImplicitArray) for v in vals)
        assert 1 <= n_implicit <= 2
        if n_implicit == 2:
            warnings.warn(f'Encountered op {primitive.name} with two implicit inputs so second will be materialized.')
            vals[1] = vals[1].materialize()
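The class-level assignment pure = lift = lambda self, val: ImplicitArrayTracer(self, val) binds one function to two attribute names, so both trace.pure(x) and trace.lift(x) wrap a value in a tracer owned by that trace. The self-contained sketch below shows the same pattern; ToyTrace and ToyTracer are hypothetical stand-ins, not JAX or fjformer classes.

```python
class ToyTracer:
    """Hypothetical stand-in for ImplicitArrayTracer: remembers its trace and value."""

    def __init__(self, trace, value):
        self._trace = trace
        self.value = value


class ToyTrace:
    # A lambda stored as a class attribute is an ordinary function, so Python
    # binds it as a method: `self` receives the ToyTrace instance.
    pure = lift = lambda self, val: ToyTracer(self, val)


trace = ToyTrace()
a = trace.pure(3.0)
b = trace.lift([1, 2])
print(a.value, b.value, a._trace is trace)  # 3.0 [1, 2] True
print(ToyTrace.pure is ToyTrace.lift)  # True: one function, two names
```

Because both names refer to the same function object, pure and lift cannot drift apart; any change to how values are wrapped applies to both.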

process_primitive(primitive, tracers, params)

Prepare the inputs of a primitive applied to ImplicitArrayTracers. The method collects the wrapped value of each input tracer and counts how many of them are ImplicitArray instances. One or two implicit inputs are supported (enforced with an assert). With two, a warning is emitted and the second input, vals[1], is materialized into a regular array.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| primitive | The JAX primitive being applied | required |
| tracers | The input tracers; each holds its wrapped value in .value | required |
| params | The primitive's keyword parameters (not used in this listing) | required |

Returns:

None; the method ends after the materialization step.
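The input checks in process_primitive can be exercised outside JAX. The sketch below is a minimal, self-contained mirror of that logic; ToyImplicitArray and prepare_inputs are hypothetical names, not part of fjformer, and materialize() here simply expands a fill value into a list. Like the original, it materializes vals[1], which assumes the two implicit inputs are the first two operands (as they are for a binary primitive such as add).

```python
import warnings


class ToyImplicitArray:
    """Hypothetical stand-in for ImplicitArray: a constant fill value of a given size."""

    def __init__(self, fill, size):
        self.fill = fill
        self.size = size

    def materialize(self):
        # Expand into an explicit (dense) list of values.
        return [self.fill] * self.size


def prepare_inputs(op_name, vals):
    """Mirror the input checks of process_primitive on a copy of vals."""
    vals = list(vals)
    n_implicit = sum(isinstance(v, ToyImplicitArray) for v in vals)
    # At least one implicit input is expected, and at most two are supported.
    assert 1 <= n_implicit <= 2
    if n_implicit == 2:
        warnings.warn(
            f"Encountered op {op_name} with two implicit inputs "
            "so second will be materialized."
        )
        # Like the original, this assumes the implicit inputs are vals[0] and vals[1].
        vals[1] = vals[1].materialize()
    return vals


x = ToyImplicitArray(0.0, 3)
y = ToyImplicitArray(1.0, 3)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = prepare_inputs("add", [x, y])

print(out[0] is x, out[1], len(caught))  # True [1.0, 1.0, 1.0] 1
```

With a single implicit input, for example prepare_inputs("mul", [x, 2.0]), the values pass through unchanged and no warning is raised.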

ImplicitArrayTracer

Bases: Tracer

Source code in src/fjformer/xrapture/tracer.py
class ImplicitArrayTracer(Tracer):
    def __init__(self, trace, value):
        """
        Create a tracer that wraps a value for an ImplicitArrayTrace.

        :param trace: The trace this tracer belongs to; forwarded to Tracer.__init__
        :param value: The wrapped value, either an ImplicitArray or any other JAX-compatible value
        """
        super().__init__(trace)
        self.value = value

    @property
    def aval(self):
        """
        The abstract value (shape and dtype) of the wrapped value.

        For an ImplicitArray, a ShapedArray is built from its shape and dtype
        without materializing the array; any other value is passed to get_aval.

        :return: The abstract value of self.value
        """
        if isinstance(self.value, ImplicitArray):
            return jax.ShapedArray(self.value.shape, self.value.dtype)
        return get_aval(self.value)

    def full_lower(self):
        """
        Return the lowest-level representation of this tracer.

        If the wrapped value is an ImplicitArray, the tracer returns itself so
        that later operations on it stay under this trace. Otherwise the wrapped
        value is passed to the module-level full_lower helper.

        :return: self if the value is an ImplicitArray, else full_lower(self.value)
        """
        if isinstance(self.value, ImplicitArray):
            return self

        return full_lower(self.value)

aval property

The abstract value (shape and dtype) of the wrapped value. For an ImplicitArray, a ShapedArray is built from its shape and dtype without materializing the array; any other value is passed to get_aval.

Returns:

The abstract value of self.value.
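The aval property dispatches on the wrapped value: implicit arrays report an abstract value built from their own shape and dtype, and everything else falls back to get_aval. The plain-Python sketch below models that dispatch; ShapedSpec, ToyImplicitArray, toy_get_aval, and aval_of are hypothetical stand-ins, not JAX or fjformer APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapedSpec:
    """Hypothetical stand-in for jax.ShapedArray: shape and dtype, no data."""
    shape: tuple
    dtype: str


class ToyImplicitArray:
    """Hypothetical implicit array that knows its shape and dtype up front."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


def toy_get_aval(value):
    """Hypothetical fallback for plain values: a non-empty flat list or a scalar."""
    if isinstance(value, list):
        return ShapedSpec((len(value),), type(value[0]).__name__)
    return ShapedSpec((), type(value).__name__)


def aval_of(value):
    # Implicit arrays describe themselves from metadata, so no dense
    # buffer is built just to report the shape and dtype.
    if isinstance(value, ToyImplicitArray):
        return ShapedSpec(value.shape, value.dtype)
    return toy_get_aval(value)


print(aval_of(ToyImplicitArray((1024, 1024), "float32")))
# ShapedSpec(shape=(1024, 1024), dtype='float32')
print(aval_of([1.0, 2.0, 3.0]))  # ShapedSpec(shape=(3,), dtype='float')
print(aval_of(7))  # ShapedSpec(shape=(), dtype='int')
```

The point of the first branch is cost: describing a large implicit array needs only its metadata, never its materialized contents.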

__init__(trace, value)

Create a tracer that wraps a value for an ImplicitArrayTrace. The trace argument is forwarded to Tracer.__init__, and the value is stored as self.value.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| trace | The trace this tracer belongs to | required |
| value | The wrapped value, either an ImplicitArray or any other JAX-compatible value | required |

full_lower()

Return the lowest-level representation of this tracer. If the wrapped value is an ImplicitArray, the tracer returns itself so that later operations on it stay under this trace. Otherwise the wrapped value is passed to the module-level full_lower helper and the result is returned.

Returns:

self if the wrapped value is an ImplicitArray, otherwise full_lower(self.value).
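In JAX releases that provide jax.core.full_lower, that helper calls val.full_lower() on a tracer and returns any other value unchanged. Combined with the method above, a stack of tracers collapses until an ImplicitArray (which keeps its tracer) or a plain value is reached. The plain-Python sketch below models that recursion; ToyTracer, ToyImplicitArray, and lower are hypothetical names, not the JAX or fjformer implementation.

```python
class ToyImplicitArray:
    """Hypothetical implicit-array stand-in."""


class ToyTracer:
    """Hypothetical stand-in for a Tracer subclass wrapping a value."""

    def __init__(self, value):
        self.value = value

    def full_lower(self):
        # Keep the tracer when it wraps an implicit array, so later
        # operations are still routed through the custom trace.
        if isinstance(self.value, ToyImplicitArray):
            return self
        # Otherwise unwrap and keep lowering the inner value.
        return lower(self.value)


def lower(val):
    """Model of the module-level helper: unwrap tracers, pass other values through."""
    if isinstance(val, ToyTracer):
        return val.full_lower()
    return val


t1 = ToyTracer(ToyImplicitArray())
t2 = ToyTracer(ToyTracer(5))
print(lower(t1) is t1)  # True
print(lower(t2))  # 5
print(lower(42))  # 42
```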
