使用Pallas为jax编写kernel扩展,需要使用JAX-Triton扩展包。由于Google的深度学习框架Jax主要是面向自己的TPU进行开发的,虽然也同时支持NVIDIA的GPU,但是支持力度有限,目前JAX-Triton只能在TPU设备上正常运行,无法保证在GPU上正常运行。


该结果使用kaggle上的TPU和GPU进行测试获得。


测试时间:

2024-01-18 21:12:09 星期四



Google的Jax框架的JAX-Triton目前只能成功运行在TPU设备上(使用Pallas为jax编写kernel扩展)——  GPU上目前无法正常运行,目前正处于 experimental 阶段_测试时间

Google的Jax框架的JAX-Triton目前只能成功运行在TPU设备上(使用Pallas为jax编写kernel扩展)——  GPU上目前无法正常运行,目前正处于 experimental 阶段_测试时间_02