官方地址:

https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad


jax框架:jax.grad_html



这里只给出几个样例代码:

  1. 设置 allow_int 参数,实现对整数类型求导:


未对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, ))

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

正常运行:

jax框架:jax.grad_正常运行_02


对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, 1))

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

报错:

jax框架:jax.grad_html_03


通过设置 allow_int 实现对整数类型求导:

import jax

def fun(x, y):
    print(x, y)
    return jax.numpy.sum(2*x[0] + y[0] + 2*x[1] + y[1])

fun_grad = jax.grad(fun, argnums=(0, 1), allow_int=True)

x = [jax.numpy.arange(0, 5).astype(jax.numpy.float32), jax.numpy.arange(1, 6).astype(jax.numpy.float32),]
y = [jax.numpy.arange(1, 6), jax.numpy.arange(2, 7),]

print( fun_grad(x, y) )

未报错运行,但是没有获得争取结果:

jax框架:jax.grad_html_04


应该这么说,在jax中不能对整数类型求导的,虽然这里设置了 allow_int 但是也不能得到正确的对整数类型的求导。