JAX介绍和快速入门示例

来源:DeepHub IMBA 本文约3300字,建议阅读10+分钟
本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念。
它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。 它对自动微分有很好的支持,对机器学习研究很有用。可以使用 jax.grad() 触发自动区分。 JAX 鼓励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。 JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT . JAX()用于JIT编译和加速代码,JIT .grad()用于求导,以及JIT .vmap()用于自动向量化或批处理。 JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。 
自动:在执行 JAX 函数的库调用时,默认情况下 JIT 编译会在后台进行。 手动:您可以使用 jax.jit() 手动请求对自己的 Python 函数进行 JIT 编译。 

JAX 使用示例
pip install jax
import jaximport jax.numpy as jnpfrom jax import randomfrom jax import grad, jitimport numpy as npkey = random.PRNGKey(0)
# runs on CPU - numpysize = 5000x = np.random.normal(size=(size, size)).astype(np.float32)%timeit np.dot(x, x.T)# 1 loop, best of 5: 1.61 s per loop
# runs on CPU - JAXsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%timeit jnp.dot(x, x.T).block_until_ready()# 1 loop, best of 5: 3.49 s per loop
# runs on GPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 102 µs, sys: 42 µs, total: 144 µs# Wall time: 155 µs# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s# Wall time: 2.16 s# 3. 10 loops, best of 5: 68.9 ms per loop
设备传输时间:将矩阵传输到 GPU 所经过的时间。耗时 0.155 毫秒。编译时间:JIT 编译经过的时间。耗时 2.16 秒。运行时间:有效的代码运行时间。耗时 68.9 毫秒。
# runs on TPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 131 µs, sys: 72 µs, total: 203 µs# Wall time: 164 µs# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms# Wall time: 837 ms# 3. 100 loops, best of 5: 16.5 ms per loop
import jax.tools.colab_tpujax.tools.colab_tpu.setup_tpu()
XLA
def selu_np(x, alpha=1.67, lmbda=1.05):return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)def selu_jax(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# runs on the CPU - numpyx = np.random.normal(size=(1000000,)).astype(np.float32)%timeit selu_np(x)# 100 loops, best of 5: 7.6 ms per loop
# runs on the CPU - JAXx = random.normal(key, (1000000,))%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms# Wall time: 124 ms# 2. 100 loops, best of 5: 4.8 ms per loop
# runs on the GPUx = random.normal(key, (1000000,))%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 103 µs, sys: 0 ns, total: 103 µs# Wall time: 109 µs# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms# Wall time: 447 ms# 3. 1000 loops, best of 5: 1.21 ms per loop
# runs on the GPUx = random.normal(key, (1000000,))selu_jax_jit = jit(selu_jax)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 70 µs, sys: 28 µs, total: 98 µs# Wall time: 104 µs# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms# Wall time: 122 ms# 3. 10000 loops, best of 5: 130 µs per loop
CPU 上的 NumPy:7.6 毫秒。 CPU 上的 JAX:4.8 毫秒(x1.58 加速)。 没有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。 带有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。 
@jitdef selu_jax_jit(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
使用 jax.grad 自动微分
def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)derivative_fn = grad(sum_logistic)print(derivative_fn(x_small))# [0.25, 0.19661197, 0.10499357]
总结
编辑:黄继彦
评论
