jax.profiler.trace_function
jax.profiler.trace_function(func, name=None, **kwargs)
Decorator that generates a trace event for the execution of a function.
For example:
>>> import jax, jax.numpy as jnp
>>>
>>> @jax.profiler.trace_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))
This will cause an "f" event to show up on the trace timeline if the function executes while the process is being traced by TensorBoard.
Arguments can be passed to the decorator via functools.partial(). For example, to give the trace event a custom name:
>>> import jax, jax.numpy as jnp
>>> from functools import partial
>>>
>>> @partial(jax.profiler.trace_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))
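functools.partial works here because the decorator takes the function as its first positional argument: partial(decorator, name=...) pre-binds the keyword arguments and leaves a callable that the `@` syntax then applies to the function. The sketch below shows this with a hypothetical decorator, `record_name` (not part of JAX), that shares the signature (func, name=None, **kwargs) but only stores the chosen event name instead of tracing anything.

```python
from functools import partial, wraps


def record_name(func, name=None, **kwargs):
    """Hypothetical decorator with the same calling convention as
    trace_function(func, name=None, **kwargs). It records the chosen event
    name and any extra keyword arguments on the wrapper instead of tracing."""

    @wraps(func)
    def wrapper(*args, **call_kwargs):
        return func(*args, **call_kwargs)

    wrapper.event_name = name or func.__qualname__
    wrapper.event_kwargs = kwargs
    return wrapper


# Applied directly: Python calls record_name(f), so name keeps its default.
@record_name
def f(x):
    return x * 2


# partial pre-binds name=...; Python then calls the partial object with g,
# which amounts to record_name(g, name="event_name").
@partial(record_name, name="event_name")
def g(x):
    return x * 2


print(f.event_name)  # f
print(g.event_name)  # event_name
print(g(3))          # 6
```

Using partial avoids writing a separate decorator factory: the same function serves both the bare `@decorator` form and the parameterized form.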