代码:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

def selu(x, alpha=1.65, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000005,))
%timeit selu(x).block_until_ready()


运行结果:

jax中对单步操作的缓存对性能造成的影响_单步操作


再次运行:

jax中对单步操作的缓存对性能造成的影响_单步操作_02



修改array的shape:

代码:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

def selu(x, alpha=1.65, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000003,))
%timeit selu(x).block_until_ready()


运行结果:

jax中对单步操作的缓存对性能造成的影响_单步操作_03


再次运行:

jax中对单步操作的缓存对性能造成的影响_缓存_04






PS. 由此可以看出,jax对单步运行其实也是使用缓存操作的,对单步操作也可以通过缓存来进行多次调用的速度提升的。