jax.experimental.host_callback module
Primitives for calling from accelerators to Python functions on the host.
Experimental: please give feedback, and expect changes.
This module introduces the host callback functions id_tap() and id_print(), which behave like the identity function but have the side-effect of sending the arguments from the device to the host and invoking a user-specified Python function (for id_tap()) or printing the arguments on the host (for id_print()). The Python function passed to id_tap() takes two positional arguments: the value tapped from the device computation and the transforms sequence, described below.
A few examples:
# calls func(2x, []) on host and returns 2x
y = id_tap(func, 2 * x)
# calls func((2x, 3x), []) and returns (2x, 3x)
y, z = id_tap(func, (2 * x, 3 * x)) # The argument can be a pytree
# calls func(2x, []) and returns y
y = id_tap(func, 2 * x, result=y) # override the result of id_tap
# calls func(2x, [], what='activation') and returns 2x
y = id_tap(functools.partial(func, what='activation'), 2 * x)
# calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
d = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use id_print() instead, with the difference that id_print() takes one positional argument (to print on the host), the optional kwarg result, and possibly additional kwargs that are also printed along with the automatic kwarg transforms.
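The calling contract described above can be illustrated with a small pure-Python model. This is only a sketch of the semantics (call the tap function for its side effect, then return the argument or `result`); `model_id_tap` and `record` are hypothetical names, nothing here runs on a device, and it is not how the module is implemented.

```python
# Pure-Python model of the id_tap contract; NOT the JAX implementation.
def model_id_tap(tap_func, arg, *, result=None):
    """Call tap_func(arg, transforms) for its side effect, then return arg (or result)."""
    transforms = []  # assumption: no JAX transformations are applied in this model
    tap_func(arg, transforms)
    return arg if result is None else result


seen = []  # everything the hypothetical tap function received


def record(arg, transforms):
    seen.append((arg, list(transforms)))


x = 3.0
y = model_id_tap(record, 2 * x)              # record(6.0, []) is called
pair = model_id_tap(record, (2 * x, 3 * x))  # the argument may be a tuple
z = model_id_tap(record, 2 * x, result=x)    # taps 6.0 but returns x
```

The model makes the role of `result` visible: the tapped value and the returned value are decoupled, which is what lets a tap be attached to an unrelated value in the computation.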
The order of execution of the tap functions is constrained by data dependency: the arguments are tapped after all the arguments are computed and before the result of the call is used. As of September 2020, the results of the tap no longer need to be used in the rest of the computation; the tap function executes based on program order. The host tap functions are executed for each device in the order in which the send operations were performed on the device.
The host tap functions for multiple devices may be interleaved. The data from the devices is received by separate threads managed by the JAX runtime (one thread per device). The runtime maintains a buffer of configurable size. When the buffer is full, all the receiving threads are paused which eventually pauses the computation on devices. The runtime has one additional thread that invokes the Python user functions with the received data. If the processing of the callbacks is slow, it may actually lead to the runtime buffer filling up, and eventually pausing the computation on the devices when they need to send something. For more details on the outfeed receiver runtime mechanism see runtime code.
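The back-pressure behaviour described above can be modelled with the standard library alone. The sketch below is an analogy, not the runtime code: `BUFFER_SIZE`, `receiver` and `callback_worker` are hypothetical names, a bounded `queue.Queue` stands in for the runtime buffer, and plain threads stand in for the per-device receivers.

```python
import queue
import threading

BUFFER_SIZE = 4        # hypothetical stand-in for the configurable runtime buffer
NUM_DEVICES = 2
ITEMS_PER_DEVICE = 10

buffer = queue.Queue(maxsize=BUFFER_SIZE)
processed = []         # (device_id, value) pairs, in the order callbacks ran


def receiver(device_id):
    """One receiver per device; put() blocks while the buffer is full."""
    for value in range(ITEMS_PER_DEVICE):
        buffer.put((device_id, value))


def callback_worker():
    """Single thread that runs the user callbacks on the received data."""
    while True:
        item = buffer.get()
        if item is None:   # sentinel: all receivers are done
            break
        processed.append(item)


worker = threading.Thread(target=callback_worker)
receivers = [threading.Thread(target=receiver, args=(d,)) for d in range(NUM_DEVICES)]
worker.start()
for t in receivers:
    t.start()
for t in receivers:
    t.join()
buffer.put(None)
worker.join()
```

In this model the items from different devices may interleave in `processed`, but the items of any single device keep their send order, matching the ordering guarantee stated earlier. A slow `callback_worker` would leave the queue full and stall every `receiver`.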
Exceptions from the user-defined tap functions are logged along with their stack traces, but the receiving threads are not stopped.
To pause execution until all data from computations already started on devices has arrived and has been processed, use barrier_wait(). This will also raise TapFunctionException if any exception occurred in one of the tap functions.
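A minimal synchronous model of this deferred error reporting follows. `ModelReceiver` and `ModelTapFunctionException` are hypothetical stand-ins for the runtime and for TapFunctionException; resetting the failure count inside `barrier_wait` is an assumption of the sketch, not documented behaviour.

```python
import logging


class ModelTapFunctionException(Exception):
    """Hypothetical stand-in for TapFunctionException."""


class ModelReceiver:
    """Logs tap-function errors and reports them later at the barrier."""

    def __init__(self):
        self._num_failures = 0

    def deliver(self, tap_func, arg, transforms=()):
        try:
            tap_func(arg, transforms)
        except Exception:
            # Log with the stack trace and keep going, as described above.
            logging.exception("tap function failed")
            self._num_failures += 1

    def barrier_wait(self):
        # Assumption of this model: the failure count is cleared once reported.
        failures, self._num_failures = self._num_failures, 0
        if failures:
            raise ModelTapFunctionException(f"{failures} tap function call(s) failed")


results = []


def flaky(arg, transforms):
    if arg < 0:
        raise ValueError("negative value tapped")
    results.append(arg)


receiver = ModelReceiver()
for value in (1, -1, 2):
    receiver.deliver(flaky, value)   # the failure on -1 does not stop delivery of 2
```

Calling `receiver.barrier_wait()` afterwards raises `ModelTapFunctionException`, which is the point at which the caller learns that a tap function failed.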
The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, and it is practically impossible to use it under conditionals or in loops of non-constant iteration count. Furthermore, code that uses the outfeed mechanism directly cannot be transformed by JAX. All these limitations are addressed by the host callback functions. The tapping API introduced here makes it easy to share the outfeed mechanism for multiple purposes, while supporting all transformations.
Note that after you have used the host callback functions, you cannot use lax.outfeed directly. You may want to call stop_outfeed_receiver() if you later need to use lax.outfeed.
We describe the behaviour under transformations in the context of the following function definition:
def power3(x):
  y = x * x
  _, y = id_print((x, y), what="x,x^2")  # Must pack multiple arguments
  return y * x
power3(3.)
# what: x,x^2 : [3., 9.]
During JAX transformations the special parameter transforms is added to contain a list of transformation descriptors in the form (transform_name, transform_params).
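A tap function usually wants to inspect this list. The helpers below are a sketch with hypothetical names (`transform_names`, `batch_dims`). They accept both descriptor shapes that appear in the printed examples of this section: `(name, params)` tuples such as `('batch', {...})` and bare names such as `'jvp'`.

```python
def transform_names(transforms):
    """Return just the transformation names, in application order."""
    names = []
    for descriptor in transforms:
        if isinstance(descriptor, str):   # bare name, e.g. 'jvp'
            names.append(descriptor)
        else:                             # (name, params) pair
            name, _params = descriptor
            names.append(name)
    return names


def batch_dims(transforms):
    """Return the batch_dims of the first 'batch' descriptor, or None."""
    for descriptor in transforms:
        if not isinstance(descriptor, str) and descriptor[0] == "batch":
            return descriptor[1]["batch_dims"]
    return None


vmap_transforms = [("batch", {"batch_dims": (0, 0)})]
grad_transforms = ["jvp", "transpose"]
```

With these inputs, `transform_names(grad_transforms)` yields `['jvp', 'transpose']` and `batch_dims(vmap_transforms)` yields `(0, 0)`.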
For jax.vmap() the arguments are batched, and transforms is extended with transformation name batch and batch_dims set to the tuple of batched dimensions (one entry per argument, None denotes an argument that was broadcast):
jax.vmap(power3)(np.arange(3.))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, 4]]
For jax.jvp() there will be two callbacks, one with the values of the primals and one with the tangents:
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2: [3., 9.]
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]
For jax.vjp() or jax.grad() there will be one callback with the values of the adjoints for the arguments. You may also see a callback with the values of the primals from the forward pass, if those values are needed for the backward pass:
jax.grad(power3)(3.)
# what: x,x^2 : [3., 9.] # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.] # from backward pass, adjoints of _, y
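The tapped values in the jvp and grad examples can be checked by hand. The plain-Python arithmetic below does not use JAX; it applies the chain rule to power3 at x = 3 with input tangent 0.1, assuming (as in the definition above) that only the tapped y is used in the final product.

```python
x, dx = 3.0, 0.1
y = x * x

# Forward pass: the tap sees the pair (x, y).
primals = (x, y)

# jvp: the tangent of x is dx, and the tangent of y = x * x is 2 * x * dx.
tangents = (dx, 2 * x * dx)

# Reverse pass: the output is y_tapped * x. The tapped copy of x (bound to `_`)
# is unused, so its adjoint is 0; the adjoint of the tapped y is d(y * x)/dy = x.
adjoints = (0.0, x)

# Full gradient: the adjoint of y flows back through y = x * x (factor 2 * x),
# plus the direct dependence of the product on x (factor y), giving 3 * x**2.
grad = adjoints[1] * 2 * x + y
```

The second tangent differs from 0.6 only by floating-point rounding, and the adjoints match the `[0., 3.]` shown for the backward pass.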
See documentation for id_tap() and id_print().
For more usage examples, see tests/host_callback_test.py.
- Still to do:
Performance tests.
Add flags for logging.
Add unit tests with mocks.
Explore a simpler API that uses Python program-order, instead of data dependency-order.
Explore implementation with outside compilation.
Explore an extended API that allows the host function to return values to the accelerator computation.
Low-level details and debugging
The C++ receiver is started automatically on the first call to id_tap(). In order to stop it properly, upon start an atexit handler is registered to call barrier_wait() with the logging name “at_exit”.
There are a few environment variables that you can use to turn on logging for the C++ outfeed receiver backend.
TF_CPP_MIN_LOG_LEVEL=0: will turn on INFO logging, needed for all below.
TF_CPP_MIN_VLOG_LEVEL=3: will make all VLOG logging up to level 3 behave like INFO logs. This may be too much, but you will see which modules are logging relevant info, and then you can select which modules to log from: TF_CPP_VMODULE=<module_name>=3
You should also use the --verbosity=2 flag so that you see the logs from Python.
For example:
`
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,cpu_transfer_manager=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackTest.test_jit_simple
`
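When such a command is launched from a script, the environment can be assembled programmatically. The helper below is a hypothetical convenience (`build_logging_env` is not part of JAX); it only composes the variables named above into a dictionary suitable for `subprocess.run(..., env=env)`.

```python
import os


def build_logging_env(modules, vlog_level=3, base_env=None):
    """Return a copy of the environment with the C++ logging variables set."""
    env = dict(os.environ if base_env is None else base_env)
    env["TF_CPP_MIN_LOG_LEVEL"] = "0"  # INFO logging, needed for VLOG output
    env["TF_CPP_VMODULE"] = ",".join(f"{m}={vlog_level}" for m in modules)
    return env


# Module names taken from the example command above; an empty base environment
# keeps the result easy to inspect.
env = build_logging_env(["outfeed_receiver", "host_callback"], base_env={})
```

Passing per-module levels through TF_CPP_VMODULE, rather than raising TF_CPP_MIN_VLOG_LEVEL globally, keeps the log volume manageable.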
API
- jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, **kwargs)
Host-callback tap primitive, like identity function with a call to tap_func.
Experimental: please give feedback, and expect changes!
id_tap behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime value of the argument.
- Parameters
tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params).
arg – the argument passed to the tap function, can be a pytree of JAX types.
result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified then the return value of id_tap is arg.
- Returns
arg, or result if given.
The order of execution is by data dependency: after all the arguments and the value of result, if present, are computed and before the returned value is used. At least one of the returned values of id_tap must be used in the rest of the computation, or else this operation has no effect.
If you want to tap a constant value, you should use the result parameter to control when it is tapped, otherwise it will be tapped during tracing of the function:
x = id_tap(tap_func, 42, result=x)
Tapping works even for code executed on accelerators and even for code under JAX transformations. Code that uses taps must be run embedded in outfeed_receiver().
For more details see the module documentation.
- jax.experimental.host_callback.id_print(arg, *, result=None, output_stream=None, threshold=None, **kwargs)
Like id_tap() with a printing tap function.
Experimental: please give feedback, and expect changes!
On each invocation of the printing tap, the kwargs if present will be printed first (sorted by keys). Then arg will be printed, with the arrays stringified with numpy.array2string.
See the id_tap() documentation.
Additional keyword arguments:
output_stream: if given then it will be used instead of the built-in print. The string will be passed as output_stream.write(s).
threshold: is passed to numpy.array2string.
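The printing order and the role of output_stream can be sketched in plain Python. `model_id_print` is a hypothetical model, not the library function: the exact text layout is an assumption, and `str()` stands in for numpy.array2string so that the sketch needs only the standard library.

```python
import io


def model_id_print(arg, *, output_stream=None, **kwargs):
    """Emit the kwargs (sorted by key), then arg, and return arg unchanged."""
    header = " ".join(f"{key}: {kwargs[key]}" for key in sorted(kwargs))
    body = str(arg)  # stand-in for numpy.array2string; layout is an assumption
    text = f"{header}\n{body}" if header else body
    if output_stream is not None:
        output_stream.write(text + "\n")  # the stream receives a single string
    else:
        print(text)
    return arg


stream = io.StringIO()
returned = model_id_print([3.0, 9.0], output_stream=stream, what="x,x^2", step=1)
```

Here `step` is written before `what` because the kwargs are sorted by key, and any object with a `write(s)` method (a file, a `StringIO`, a logger adapter) can collect the output in place of the built-in print.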
- jax.experimental.host_callback.outfeed_receiver()
Implements a barrier after a block of code.
DEPRECATED: This function is no longer necessary; it is kept for backwards compatibility. At the moment it implements a barrier_wait after the body of the context manager finishes.
- exception jax.experimental.host_callback.TapFunctionException
Signals that some tap function had exceptions.
Raised by outfeed_receiver().