官方地址:
https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad
这里只给出几个样例代码:
- 设置 allow_int 参数,实现对整数类型求导:
未对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, ))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
正常运行:
对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1))
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
报错:
通过设置 allow_int 实现对整数类型求导:
import jax
def fun(x, y):
print(x, y)
return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])
fun_grad = jax.grad(fun, argnums=(0, 1), allow_int=True)
x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]
print( fun_grad(x, y) )
未报错运行,但是没有获得争取结果:
应该这么说,在jax中不能对整数类型求导的,虽然这里设置了 allow_int 但是也不能得到正确的对整数类型的求导。