前言

本文适用对象:任何接触过 TensorFlow, Pytorch, Keras 并且已经开始了解或尝鲜 Jax 的人群。如果是没有接触过任何深度学习框架的人群,这篇文章可能不适合你。在开始学习之前,你应该对 PyTorch 或 TensorFlow 有一定的了解。Jax 可能是一个比较难学的库,但值得一学。为什么使用 Jax 的理由这里就不多赘述了。就我个人而言,tf 或者 torch 自定义损失函数训练的速度实在是不太满意,即使这过程中换用了 numba 仍然差强人意,加之我是 tf 的旧党,所以不难预料的投入了 Jax 的怀抱。

正文开始

如果是经常用 Tensorflow Keras 或者 Pytorch Lightning 的炼丹师,一定会喜欢 fit 这个方法。所以本文以实现一个简单且常用的 fit 方法来快速上手 Jax,而且实现的这个 fit 方法基本上可以复用在很多项目中。另外再次强调,这篇文章可能不适合入门,但是很适合快速上工(从删库到跑路)。

本文实现的 fit 方法需要安装如下依赖,如果你已经使用过 Jax ,基本以下依赖库想必都已经了解了。

  • jax (jax, jaxlib)
  • flax 定义你的模型
  • optax 优化器,学习率,损失函数
  • orbax 用于保存 checkpoints
  • tqdm 显示进度条,用过多解释了
  • tensorboardX 一个三方的 tensorboard 的 python 库,用于输出一些训练过程日志

本文也是用的这个经典组合:Jax + Flax + Optax + Orbax,硬件加速 + 网络结构 + 损失函数 + 保存储存点

快速上手

先看看一个训练模型的模板,但只需要修改脚本中的三个关键代码部分。

import jax, flax, optax, orbax
from fit import lr_schedule, TrainState

# 准备你自己的数据集
train_ds, test_ds = your_dataset()
# 学习率
lr_fn = lr_schedule(
    base_lr=1e-3,
    steps_per_epoch=len(train_ds),
    epochs=100,
    warmup_epochs=5,
)

# key 1: 你的模型
model = YourModel()

# 初始化 key 和你的模型
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 28, 28, 1)) # MNIST 示例输入大小
# 注意这里 train=True, 区别模型的训练和评价模式
var = model.init(key, x, train=True)
# 固定模板,直接复制就能用
state = TrainState.create(
    apply_fn=model.apply,
    params=var['params'],
    batch_stats=var['batch_stats'],
    tx=optax.inject_hyperparams(optax.adam)(lr_fn),
)

# 你的训练函数,详情参考下个章节
@jax.jit
def train_step():
    # key 2: 你的损失函数
    def loss_fn():
        ...
    return state, loss_dict, opt_state

# 你的评价函数
@jax.jit
def eval_step():
    # key 3: 你的评价函数
    ...
    return acc

# 一些必要的参数,epoches 之类
fit(state, train_ds, test_ds,
    train_step=train_step,
    eval_step=eval_step,
    eval_freq=1,
    num_epochs=10,
    log_name='mnist',
)

使用方法

让我们从一个简单的例子开始,在 MNIST 数据集上训练一个模型。首先,在训练脚本中导入 fit 模块。

from fit import *

在训练之前,你需要定义模型、损失函数和评估函数。让我们从模型开始。

模型

下面是一个非常简单的模型示例。setup 函数用来定义模型结构,__call__ 函数定义模型的前向传播。

class Model(nn.Module):
    def setup(self):
        self.conv1 = nn.Conv(features=16, kernel_size=(3, 3))
        self.dense1 = nn.Dense(features=10)

    # train=False 用于评价模式
    # 如果你使用了 dropout 或者 batch normalization 层
    # 我打赌你会用到它
    @nn.compact
    def __call__(self, x, train=False):
        # 简单的 conv + bn + relu + 全连接层
        x = self.conv1(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        # dropout 层
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        # 展平
        x = x.reshape((x.shape[0], -1))
        x = self.dense1(x)
        return x

接下来,你只需要考虑两件事:损失函数和评估函数。下面的 train_step 函数是训练模型的一个通用模板。state 对象是一个基于 TrainState 的对象的改进,其中不仅包含了模型参数、 Batch 状态和其他必要信息。batch 对应的是输入数据,opt_state 是优化器状态。

不要担心面对这个复杂的 train_step 函数,它只是一个模板。你可以复制并粘贴到你的脚本中,只需修改 loss_fn 函数即可。

@jax.jit
def train_step(state: TrainState, batch, opt_state):
    x, y = batch
    def loss_fn(params):
        logits, updates = state.apply_fn({
            'params': params,
            'batch_stats': state.batch_stats
        }, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
        loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
        loss_dict = {'loss': loss}
        return loss, (loss_dict, updates)

    # gradient and update
    (_, (loss_dict, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
    # update optimizer state
    _, opt_state = state.tx.update(grads, opt_state)
    return state, loss_dict, opt_state

损失函数 is all you need

让我们把重点放在 loss_fn 函数上。让我们从伪 pytorch 风格的代码开始,这对理解 Jax 中 train_step 方法里的 loss_fn 函数很有帮助。

def loss_fn():
    pred_y = model(x, train=False)
    loss = criterion(pred_y, true_y)
    return loss

很简单对吧?让我们继续。

def loss_fn(params):
    pred_y, updates = state.apply_fn({'params': params}, x, train=True)
    loss = criterion(pred_y, y_true)
    # 方便将一些你需要的日志记录展示在 tensorboard 里
    loss_dict = {'loss': loss}
    return loss, (loss_dict, updates)

接下来,让我们为 loss_fn 函数添加更多细节,例如 batch state 和 dropout key。这是 train_step 函数中 loss_fn 函数的完整版本。

def loss_fn(params):
    logits, updates = state.apply_fn({
        'params': params,
        'batch_stats': state.batch_stats
    }, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
    loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
    loss_dict = {'loss': loss}
    return loss, (loss_dict, updates)

既然这篇文章标题叫上工指南,mnist 这个例子显然就不太合适了,毕竟谁家上工不是得整一堆损失函数辅助配合了用的,所以需要在 jax.value_and_grad 函数中开启 has_aux=True 然后写成以下这个样子,并且为了将 log 输出到 tensorboard 里,所以这里用一个字典返回。

@jax.jit
def train_step(state: TrainState, batch, opt_state):
    x, y = batch
    def loss_fn(params):
        logits, updates = state.apply_fn({
            'params': params,
            'batch_stats': state.batch_stats
        }, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
        loss_1 = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
        loss_2 = jnp.mean(jnp.square(logits - y))
        loss = loss_1 + loss_2
        loss_dict = {'loss': loss, 'loss_1': loss_1, 'loss_2': loss_2}
        return loss, (loss_dict, updates)

    # gradient and update
    (_, (loss_dict, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
    # update optimizer state
    _, opt_state = state.tx.update(grads, opt_state)
    return state, loss_dict, opt_state

需要注意的是,损失函数应返回总损失值和要记录到 tensorboard 的字典。

评价函数

现在,让我们继续使用伪 pytorch 风格代码来看看评估函数。

def eval_step():
    true_x, true_y = data
    model.eval()
    pred_y = model(true_x)
    # 你的计算准确度的函数
    acc = metric(pred_y, true_y)
    return acc

在 pytorch 中,可以使用 model.eval() 函数将模型切换到评估模式。因为在训练和评估模式中,BN 层和 Dropout 层的行为不同。在 Jax 中,你需要在 apply_fn 函数中设置 train=False 参数。需要注意的是,如果使用 BN 层和Dropout 层,模型结构在训练和评估模式下应该是不同的,请参阅模型部分的 __call__ 函数。

与 train_step 函数类似,只需要传入 state 对象和 batch 对象。

@jax.jit
def eval_step(state: TrainState, batch):
    x, y = batch
    logits = state.apply_fn({
        'params': state.params,
        'batch_stats': state.batch_stats,
        }, x, train=False)
    acc = jnp.equal(jnp.argmax(logits, -1), y).mean()
    return acc

数据准备

TensorFlow Datasets

ds = tfds.load("mnist", split="train", as_supervised=True)
train_ds = ds.take(50000).map(lambda x, y: (x / 255, y))

Torchvision Datasets

ds = torchvision.datasets.MNIST(
    root="data", train=True, download=True,
    transform=torchvision.transforms.ToTensor()
)
train_ds = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)

学习率

顺便说一下,lr_schedule 用于创建学习率函数,这是 TrainState 对象所必需的。当然,你也可以配置你偏好的 lr_schedule 或者直接用默认的 lr_schedule 。

lr_fn = lr_schedule(base_lr=1e-3,
    steps_per_epoch=len(train_ds),
    epochs=100,
    warmup_epochs=5,
)

此外,你还可以定义自己的链式更新,详情请查看 optax 库。

state = TrainState.create(
    apply_fn=model.apply,
    params=var['params'],
    batch_stats=var['batch_stats'],
    # 链式组合
    tx=optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(lr_fn)),
)

最后调用fit函数开始训练。

fit(state, train_ds, test_ds,
    train_step=train_step,
    eval_step=eval_step,
    # evaluate the model every N epochs (default 1)
    eval_freq=1,
    num_epochs=10,
    # log name for tensorboard
    log_name='mnist',
)

可视化训练过程

可以打开 Tensorboard 查看训练过程或检查任何损失和准确度指标。

Q&A

什么是 @jax.jit 装饰器?

这是一个将函数编译为单个静态函数的装饰器,可以在 GPU 或 TPU 上执行,如果你想加快训练过程,尤其是你自己的损失函数和评估函数,可以添加 @jax.jit 装饰器。

什么是 batch state 和 dropout key

Batch State 用于存储批处理归一化统计数据,而 Dropout Key 用于生成 Dropout 层的随机掩码。

完整的代码在我的 github 上⬇️,欢迎 fork 和 star。

Jax.fit