So let's do a quick experiment with IPython's %timeit. I defined two functions, one with the tf.function decorator and one without:
import tensorflow as tf

def test_no_tracing(inp):
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(inp)
        out = inp**2
        out = tf.unstack(out, axis=-1)
    grad = []
    for x in out:
        grad.append(tape.gradient(x, inp))
    del tape
    return tf.stack(grad, axis=-1)

@tf.function
def test_tracing(inp):
    print("Tracing")
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(inp)
        out = inp**2
        out = tf.unstack(out, axis=-1)
    grad = []
    for x in out:
        grad.append(tape.gradient(x, inp))
    del tape
    return tf.stack(grad, axis=-1)

inp = tf.random.normal((32, 100, 50))
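As a side note, what these functions compute is the Jacobian of the elementwise square: stacking the gradient of each output slice with respect to the full input yields a matrix that is diagonal, with 2*x on the diagonal. A minimal NumPy sketch for a tiny 1-D input (the shapes are illustrative, not the (32, 100, 50) used in the benchmark):

```python
import numpy as np

# For out = x**2, the gradient of out[i] w.r.t. x is the vector with
# 2*x[i] at position i and zeros elsewhere. Stacking those gradients
# along the last axis reproduces the diagonal Jacobian.
x = np.array([1.0, 2.0, 3.0])
jac = np.stack([np.where(np.arange(3) == i, 2 * x, 0.0) for i in range(3)], axis=-1)

assert np.allclose(np.diag(jac), 2 * x)                    # diagonal is 2*x
assert np.allclose(jac - np.diag(np.diag(jac)), 0.0)       # off-diagonal is zero
```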
Let's see the results. With the tf.function decorator:
In [2]: %timeit test_tracing(inp)
Tracing
2021-01-22 15:22:15.003262: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-01-22 15:22:15.076448: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
10.3 ms ± 579 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
And without:
In [3]: %timeit test_no_tracing(inp)
71.7 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The function decorated with tf.function is roughly 7 times faster. It can seem slower if you run the function only once, because the decorated function pays the overhead of tracing, i.e. converting the Python code into a graph. Once that tracing is done, the code is much quicker.
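The pattern can be mimicked in plain Python, without TensorFlow, as a rough analogy (the Traced class and its fake 0.05 s "tracing" delay are invented for illustration, not part of tf.function's actual implementation): an expensive one-time setup on the first call, followed by cheap calls that reuse the cached result.

```python
import time

class Traced:
    """Hypothetical stand-in: builds its 'graph' lazily on first call."""
    def __init__(self):
        self.graph = None

    def __call__(self, x):
        if self.graph is None:
            time.sleep(0.05)            # one-time "tracing" cost
            self.graph = lambda v: v * v
        return self.graph(x)            # cheap once traced

f = Traced()
t0 = time.perf_counter(); f(3); first = time.perf_counter() - t0
t0 = time.perf_counter(); f(3); second = time.perf_counter() - t0
assert first > second  # only the first call pays the setup cost
```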
This can be verified by running the function a single time, before it has been traced. We can do that by telling %timeit to do only one loop and one repetition:
In [2]: %timeit -r 1 -n 1 test_tracing(inp)
Tracing
2021-01-22 15:29:47.189850: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-01-22 15:29:47.284413: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
4.97 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Here, the time is much larger, closer to what you report in your question. But once the tracing is done, the traced function is a lot quicker. Let's run it again:
In [3]: %timeit -r 1 -n 1 test_tracing(inp)
29.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
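If you are not in IPython, the same one-shot measurement can be reproduced with the standard-library timeit module: %timeit -r 1 -n 1 corresponds to repeat=1 and number=1. A stand-in function is used below, since this snippet does not assume TensorFlow is installed:

```python
import timeit

def f():
    # Stand-in workload; replace with the function you want to time.
    return sum(x * x for x in range(10_000))

# One repetition, one loop per repetition: the equivalent of
# `%timeit -r 1 -n 1 f()` outside IPython.
times = timeit.repeat(f, repeat=1, number=1)
print(f"{times[0]:.6f} s for a single call")
```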
You can read more about how to achieve better performance with tf.function in the guide: Better performance with tf.function.