代码:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
def selu(x, alpha=1.65, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000005,))
%timeit selu(x).block_until_ready()
运行结果:
再次运行:
修改array的shape:
代码:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
def selu(x, alpha=1.65, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000003,))
%timeit selu(x).block_until_ready()
运行结果:
再次运行:
PS. 由此可以看出,jax对单步运行其实也是使用缓存操作的,对单步操作也可以通过缓存来进行多次调用的速度提升的。