
utils.utils

Timer

Source code in src/python/easydel/utils/utils.py
class Timer:

    def __init__(self, name):
        """
        Create a stopped timer with no accumulated time.

        :param self: The timer instance
        :param name: Name used to identify the timer
        :return: None

        """
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """
        Start the timer by recording the current time.
        Asserts that the timer is not already running.

        :param self: The timer instance
        :return: None

        """
        assert not self.started_, "timer has already been started"
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """
        Stop the timer and add the time since the last start() call to the accumulated total.
        Asserts that the timer is running.

        :param self: The timer instance
        :return: None; read the accumulated time with elapsed()

        """
        assert self.started_, "timer is not started"
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """
        Set the accumulated time to 0.0 and mark the timer as stopped.

        :param self: The timer instance
        :return: None

        """
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """
        Return the accumulated time in seconds, including the current run if the timer is running.
        A running timer is stopped to take the reading and then started again, so it keeps running either way.
        If reset is True, the accumulated time is cleared after the reading; if reset is False, it is kept.

        :param self: The timer instance
        :param reset: Whether to clear the accumulated time after reading it
        :return: The accumulated time in seconds

        """
        started_ = self.started_
        if self.started_:
            self.stop()
        elapsed_ = self.elapsed_
        if reset:
            self.reset()
        if started_:
            self.start()
        return elapsed_
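The following is a minimal usage sketch of the class above. So that it runs without EasyDeL installed, it uses a condensed copy of `Timer` (docstrings dropped, logic unchanged); the timer name `"step"` and the sleep duration are arbitrary illustrative choices.

```python
import time


class Timer:
    """Condensed stand-in for the Timer class shown above (same logic, no docstrings)."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        assert not self.started_, "timer has already been started"
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, "timer is not started"
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        started_ = self.started_
        if self.started_:
            self.stop()
        elapsed_ = self.elapsed_
        if reset:
            self.reset()
        if started_:
            self.start()
        return elapsed_


timer = Timer("step")  # "step" is an arbitrary illustrative name

# Two start/stop pairs accumulate into one total.
for _ in range(2):
    timer.start()
    time.sleep(0.05)
    timer.stop()

kept = timer.elapsed(reset=False)    # read the total and keep it
cleared = timer.elapsed(reset=True)  # read the same total, then zero it
after = timer.elapsed(reset=False)   # 0.0: the timer is stopped and was just reset

# Calling start() on a running timer trips the assertion.
timer.start()
try:
    timer.start()
    double_start_failed = False
except AssertionError:
    double_start_failed = True
timer.stop()
```

Because the checks in `start()` and `stop()` are plain `assert` statements, they are skipped when Python runs with `-O`, so they are best treated as a debugging aid rather than as error handling.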

__init__(name)

Creates a stopped timer with no accumulated time.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The timer instance | required |
| name | | Name used to identify the timer | required |

Returns:

| Type | Description |
| --- | --- |
| | None |

Source code in src/python/easydel/utils/utils.py
def __init__(self, name):
    """
    Create a stopped timer with no accumulated time.

    :param self: The timer instance
    :param name: Name used to identify the timer
    :return: None

    """
    self.name_ = name
    self.elapsed_ = 0.0
    self.started_ = False
    self.start_time = time.time()

elapsed(reset=True)

Returns the accumulated time in seconds, including the current run if the timer is running. A running timer is stopped to take the reading and then started again, so it keeps running either way. If reset is True, the accumulated time is cleared after the reading; if reset is False, it is kept.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The timer instance | required |
| reset | | Whether to clear the accumulated time after reading it | True |

Returns:

| Type | Description |
| --- | --- |
| | The accumulated time in seconds |

Source code in src/python/easydel/utils/utils.py
def elapsed(self, reset=True):
    """
    Return the accumulated time in seconds, including the current run if the timer is running.
    A running timer is stopped to take the reading and then started again, so it keeps running either way.
    If reset is True, the accumulated time is cleared after the reading; if reset is False, it is kept.

    :param self: The timer instance
    :param reset: Whether to clear the accumulated time after reading it
    :return: The accumulated time in seconds

    """
    started_ = self.started_
    if self.started_:
        self.stop()
    elapsed_ = self.elapsed_
    if reset:
        self.reset()
    if started_:
        self.start()
    return elapsed_

reset()

Sets the accumulated time to 0.0 and marks the timer as stopped.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The timer instance | required |

Returns:

| Type | Description |
| --- | --- |
| | None |

Source code in src/python/easydel/utils/utils.py
def reset(self):
    """
    Set the accumulated time to 0.0 and mark the timer as stopped.

    :param self: The timer instance
    :return: None

    """
    self.elapsed_ = 0.0
    self.started_ = False

start()

Starts the timer by recording the current time. Asserts that the timer is not already running.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The timer instance | required |

Returns:

| Type | Description |
| --- | --- |
| | None |

Source code in src/python/easydel/utils/utils.py
def start(self):
    """
    Start the timer by recording the current time.
    Asserts that the timer is not already running.

    :param self: The timer instance
    :return: None

    """
    assert not self.started_, "timer has already been started"
    self.start_time = time.time()
    self.started_ = True

stop()

Stops the timer and adds the time since the last start() call to the accumulated total. Asserts that the timer is running.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The timer instance | required |

Returns:

| Type | Description |
| --- | --- |
| | None; read the accumulated time with elapsed() |

Source code in src/python/easydel/utils/utils.py
def stop(self):
    """
    Stop the timer and add the time since the last start() call to the accumulated total.
    Asserts that the timer is running.

    :param self: The timer instance
    :return: None; read the accumulated time with elapsed()

    """
    assert self.started_, "timer is not started"
    self.elapsed_ += time.time() - self.start_time
    self.started_ = False

Timers

Group of timers.

Source code in src/python/easydel/utils/utils.py
class Timers:
    """Group of timers."""

    def __init__(self, use_wandb, tensorboard_writer):
        self.timers = {}
        self.use_wandb = use_wandb
        self.tensorboard_writer = tensorboard_writer

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = Timer(name)
        return self.timers[name]

    def write(self, names, iteration, normalizer=1.0, reset=False):

        """
        Write the elapsed time of each named timer to TensorBoard and/or Weights & Biases under the tag timers/<name>.

        :param self: The Timers instance
        :param names: Iterable of timer names to write
        :param iteration: Step index to record the values at
        :param normalizer: Positive value the elapsed time is divided by
        :param reset: Whether to reset each timer after writing it
        :return: None

        """
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer

            if self.tensorboard_writer:
                self.tensorboard_writer.add_scalar(f"timers/{name}", value, iteration)

            if self.use_wandb:
                wandb.log({f"timers/{name}": value}, step=iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """
        Print the elapsed time of each named timer to the terminal, in milliseconds divided by normalizer.

        :param self: The Timers instance
        :param names: A timer name or a list of timer names to log
        :param normalizer: Positive value the elapsed time is divided by
        :param reset: Whether to reset each timer after logging it
        :return: None

        """
        assert normalizer > 0.0

        if isinstance(names, str):
            names = [names]
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            termcolor.cprint(
                # elapsed() is in seconds, so * 1000.0 gives milliseconds
                f"Time Took to Complete Task {name} (milliseconds) : "
                f"{termcolor.colored(elapsed_time, color='white', force_color=True)}",
                color="cyan",
                force_color=True
            )

log(names, normalizer=1.0, reset=True)

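Prints the elapsed time of each named timer to the terminal. Each value is the timer's elapsed time in seconds multiplied by 1000.0 (that is, milliseconds) and divided by normalizer. A single name may be passed as a string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The Timers instance | required |
| names | | A timer name or a list of timer names to log | required |
| normalizer | | Positive value the elapsed time is divided by | 1.0 |
| reset | | Whether to reset each timer after logging it | True |

Returns:

| Type | Description |
| --- | --- |
| | None |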

Source code in src/python/easydel/utils/utils.py
def log(self, names, normalizer=1.0, reset=True):
    """
    Print the elapsed time of each named timer to the terminal, in milliseconds divided by normalizer.

    :param self: The Timers instance
    :param names: A timer name or a list of timer names to log
    :param normalizer: Positive value the elapsed time is divided by
    :param reset: Whether to reset each timer after logging it
    :return: None

    """
    assert normalizer > 0.0

    if isinstance(names, str):
        names = [names]
    for name in names:
        elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
        termcolor.cprint(
            # elapsed() is in seconds, so * 1000.0 gives milliseconds
            f"Time Took to Complete Task {name} (milliseconds) : "
            f"{termcolor.colored(elapsed_time, color='white', force_color=True)}",
            color="cyan",
            force_color=True
        )
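The value printed by `log` is the timer's elapsed seconds multiplied by 1000.0 and divided by `normalizer`. As a small worked example, the helper `log_value` below is hypothetical (it is not part of the library) and only mirrors that expression; the 2.5 s and 100 steps are assumed figures.

```python
def log_value(elapsed_seconds: float, normalizer: float = 1.0) -> float:
    """Hypothetical helper mirroring the expression inside Timers.log."""
    assert normalizer > 0.0
    return elapsed_seconds * 1000.0 / normalizer


# 2.5 s accumulated over 100 training steps, normalized per step.
per_step_ms = log_value(2.5, normalizer=100)
print(per_step_ms)  # 25.0
```

Passing the number of steps as `normalizer` therefore turns an accumulated total into an average per step.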

write(names, iteration, normalizer=1.0, reset=False)

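Writes the elapsed time of each named timer to TensorBoard and/or Weights & Biases under the tag timers/<name>. Each value is the timer's elapsed time in seconds divided by normalizer. Unlike log, names is iterated as given, so pass a list even for a single timer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| self | | The Timers instance | required |
| names | | Iterable of timer names to write | required |
| iteration | | Step index to record the values at | required |
| normalizer | | Positive value the elapsed time is divided by | 1.0 |
| reset | | Whether to reset each timer after writing it | False |

Returns:

| Type | Description |
| --- | --- |
| | None |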

Source code in src/python/easydel/utils/utils.py
def write(self, names, iteration, normalizer=1.0, reset=False):

    """
    Write the elapsed time of each named timer to TensorBoard and/or Weights & Biases under the tag timers/<name>.

    :param self: The Timers instance
    :param names: Iterable of timer names to write
    :param iteration: Step index to record the values at
    :param normalizer: Positive value the elapsed time is divided by
    :param reset: Whether to reset each timer after writing it
    :return: None

    """
    assert normalizer > 0.0
    for name in names:
        value = self.timers[name].elapsed(reset=reset) / normalizer

        if self.tensorboard_writer:
            self.tensorboard_writer.add_scalar(f"timers/{name}", value, iteration)

        if self.use_wandb:
            wandb.log({f"timers/{name}": value}, step=iteration)
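To illustrate what `write` sends to its writer, here is a sketch that replays the same loop against a stub. `RecordingWriter` and `write_timings` are hypothetical names introduced for this example, the readings are assumed values, and the wandb branch is left out so the snippet has no dependencies.

```python
class RecordingWriter:
    """Hypothetical stand-in for a TensorBoard writer; it only records add_scalar calls."""

    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))


def write_timings(elapsed_by_name, names, writer, iteration, normalizer=1.0):
    """Mirror of the loop in Timers.write, minus the wandb branch."""
    assert normalizer > 0.0
    for name in names:
        value = elapsed_by_name[name] / normalizer
        if writer:
            writer.add_scalar(f"timers/{name}", value, iteration)


writer = RecordingWriter()
# Assumed readings in seconds, accumulated over 10 steps.
readings = {"forward": 4.0, "backward": 6.0}
write_timings(readings, ["forward", "backward"], writer, iteration=10, normalizer=10)
print(writer.scalars)
# [('timers/forward', 0.4, 10), ('timers/backward', 0.6, 10)]
```

Each timer ends up as its own scalar series under the `timers/` prefix, with the iteration as the x-axis value.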

get_mesh(shape=(1, -1, 1, 1), axis_names=('dp', 'fsdp', 'tp', 'sp'))

Helper that builds a JAX Mesh over all available devices. One entry of shape may be -1, in which case it is inferred from the number of devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| shape | Sequence[int] | Shape of the device mesh; one entry may be -1 | (1, -1, 1, 1) |
| axis_names | Sequence[str] | Names of the mesh axes, one per entry in shape | ('dp', 'fsdp', 'tp', 'sp') |

Returns:

| Type | Description |
| --- | --- |
| | A jax.sharding.Mesh object |

Source code in src/python/easydel/utils/utils.py
def get_mesh(
        shape: typing.Sequence[int] = (1, -1, 1, 1),
        axis_names: typing.Sequence[str] = ("dp", "fsdp", "tp", "sp")
):
    """
    Helper that builds a JAX Mesh over all available devices.

    :param shape: typing.Sequence[int]: Shape of the device mesh; one entry may be -1 and is inferred from the number of devices
    :param axis_names: typing.Sequence[str]: Names of the mesh axes, one per entry in shape
    :return: A jax.sharding.Mesh object

    """
    from jax.sharding import Mesh
    from jax.experimental import mesh_utils
    array = jnp.ones((len(jax.devices()), 1)).reshape(shape)
    return Mesh(mesh_utils.create_device_mesh(array.shape), axis_names)
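The `-1` in the default shape is resolved by the `reshape` call: the array has one element per device, so the missing axis becomes whatever size uses all devices. The sketch below reproduces that inference in pure Python so it runs without JAX; `resolve_mesh_shape` is a hypothetical helper, not part of the library, and the device count of 8 is an assumption.

```python
import math
from typing import Sequence, Tuple


def resolve_mesh_shape(num_devices: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: resolve a single -1 entry the way reshape does."""
    unknown = [i for i, dim in enumerate(shape) if dim == -1]
    assert len(unknown) <= 1, "at most one dimension may be -1"
    # Product of the explicitly given axis sizes.
    known = math.prod(dim for dim in shape if dim != -1)
    if unknown:
        assert num_devices % known == 0, "device count must be divisible by the fixed axes"
        resolved = list(shape)
        resolved[unknown[0]] = num_devices // known
        return tuple(resolved)
    assert known == num_devices, "shape must use every device"
    return tuple(shape)


# With 8 devices (an assumed count), the default shape puts all of them on the fsdp axis.
print(resolve_mesh_shape(8, (1, -1, 1, 1)))  # (1, 8, 1, 1)
# Fixing tp=2 leaves 4 devices for fsdp.
print(resolve_mesh_shape(8, (1, -1, 2, 1)))  # (1, 4, 2, 1)
```

A shape whose fixed axes do not divide the device count fails here with an assertion; in `get_mesh` itself the corresponding failure would come from the `reshape` call.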