如何使用tensor2tensor自定义数据训练模型

由于tensor2tensor高度的封装,内部添加和一些数据集,和一些常见的问题,所以在直接用起来比较方便。但是如果想要用不同的数据训练模型,或者是用模型解决一个其他的问题,就要费一番功夫了。

这里主要是解决了用自己的数据集,使用tensor2tensor训练一个英中翻译模型,当然训练中英,只需要加上`_rev`即可。

如果要使用自己的数据集,根据github文档可知,它是没有告诉你怎么做滴,那么怎么办呢,就不做了么? 当然不可能。文档中提到了可以定义自己的问题,然后在里面可以定义一些内容,例如单词表大小,数据集的位置,分词方式等。嗯,看到了,是有数据集的位置的,那么直接定义一个问题不就可以么,然后在里面指定相应的数据集位置,那么看下代码:

完整代码在后面:

首先是要有一个自定义的用户目录,也就是参数‘--usr_dir ’ 的值。

接下来,创建一个 problem_name.py 文件,并且里面有__init__.py 这个文件,并且在init.py 中把problem_name 导入,这样才能够被`t2t-datagen`和`t2t-trainer`识别,并注册到t2t里面。就像下面这样。

mobilesam 自定义训练_自定义训练模型

在创建完文件之后就要对文件的内容进行编写了。

一些导入文件的代码略过(篇幅有限)

然后两个数据集:

_NC_TRAIN_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.tgz",
    ["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]

_NC_TEST_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
    ("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]

上面代码:重要的也就是这两个数据集了:其中一个是训练集, 一个是测试集,开发集程序内部会进行分割,这里就不考虑。

首先是列表内容元素的第一个链接指的是元素的位置,也就是网络位置,由于我们要是用的是本地的文件,这里就是一个僵尸文件,也就是一个虚拟地址+僵尸压缩文件。主要作用是避免内部生成单词表和数据的时候进行数据的下载。

后面一个"raw-train.zh-en.en", "raw-train.zh-en.zh" 也就是平行语料,也就是自己的数据集文件,这里面的文件只要是处理干净就行了,关于分词的话,谷歌内部的新的分词方式subword基本能满足使用,某些论文中甚至要优于bpe分词方式。

def create_dummy_tar(tmp_dir, dummy_file_name):
    dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
    if not os.path.exists(dummy_file_path):
        tf.logging.info("Generating dummy file: %s", dummy_file_path)
        tar_dummy = tarfile.open(dummy_file_path, "w:gz")
        tar_dummy.close()
    tf.logging.info("File %s is already exists or created", dummy_file_name)

上面函数主要是为了防止t2t的数据生成工具进行下载,而创建僵尸压缩文件。对于每一个数据集都会进行检查。

def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        train = dataset_split == problem.DatasetSplit.TRAIN
        train_dataset = self.get_training_dataset(tmp_dir)
        datasets = train_dataset if train else _NC_TEST_DATASETS
        for item in datasets:
            dummy_file_name = item[0].split("/")[-1]
            create_dummy_tar(tmp_dir, dummy_file_name)
            s_file, t_file = item[1][0], item[1][1]
            if not os.path.exists(os.path.join(tmp_dir, s_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
            if not os.path.exists(os.path.join(tmp_dir, t_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)

        source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
        target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]

    ...

    return text_problems.text2text_generate_encoded(
            text_problems.text2text_txt_iterator(data_path + ".lang1",
                                                 data_path + ".lang2"),
            source_vocab, target_vocab)

上面函数主要是生成样本数据,也就是在data文件夹下面的一些数据,同样如果在data目录下面没有单词表文件的话,会根据数据集生成单词表文件。

至此,基本已经完成了所有操作,只需要用 t2t-datagen 和 t2t-trainer 生成数据并进行训练即可!

另外,提一下,自定义的类名应该是驼峰法命名,定义的问题对应根据驼峰规则用横线隔开,例如这里我定义的是:translate_enzh_sub32k,对应类名 TranslateEnzhSub32k。

~ ~

辣么,如何使用已经有了单词表,平行语料之后应该如何定义问题呢,见下一篇: 

完成代码:

# coding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import tokenizer
from tensor2tensor.utils import registry
import tensorflow as tf

from collections import defaultdict

_NC_TRAIN_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.tgz",
    ["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]

_NC_TEST_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
    ("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]

def create_dummy_tar(tmp_dir, dummy_file_name):
    dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
    if not os.path.exists(dummy_file_path):
        tf.logging.info("Generating dummy file: %s", dummy_file_path)
        tar_dummy = tarfile.open(dummy_file_path, "w:gz")
        tar_dummy.close()
    tf.logging.info("File %s is already exists or created", dummy_file_name)


def get_filename(dataset):
    return dataset[0][0].split("/")[-1]


@registry.register_problem
class TranslateEnzhSub32k(translate.TranslateProblem):
    """Problem spec for WMT En-De translation, BPE version."""

    # 设定单词表生成大小
    @property
    def vocab_size(self):
        return 32000

    # 使用 bpe 进行分词
    # @property
    # def vocab_type(self):
    #    return text_problems.VocabType.TOKEN

    # 超过单词表之后的词的表示,None 表示用元字符替换
    @property
    def oov_token(self):
        """Out of vocabulary token. Only for VocabType.TOKEN."""
        return None

    @property
    def approx_vocab_size(self):
        return 32000

    @property
    def source_vocab_name(self):
        return "vocab.enzh-sub-en.%d" % self.approx_vocab_size

    @property
    def target_vocab_name(self):
        return "vocab.enzh-sub-zh.%d" % self.approx_vocab_size

    def get_training_dataset(self, tmp_dir):
        full_dataset = _NC_TRAIN_DATASETS
        # 可以添加一些其他的数据集在这里
        return full_dataset

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        train = dataset_split == problem.DatasetSplit.TRAIN
        train_dataset = self.get_training_dataset(tmp_dir)
        datasets = train_dataset if train else _NC_TEST_DATASETS
        for item in datasets:
            dummy_file_name = item[0].split("/")[-1]
            create_dummy_tar(tmp_dir, dummy_file_name)
            s_file, t_file = item[1][0], item[1][1]
            if not os.path.exists(os.path.join(tmp_dir, s_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
            if not os.path.exists(os.path.join(tmp_dir, t_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)

        source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
        target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
        source_vocab = generator_utils.get_or_generate_vocab(
            data_dir,
            tmp_dir,
            self.source_vocab_name,
            self.approx_vocab_size,
            source_datasets,
            file_byte_budget=1e8)
        target_vocab = generator_utils.get_or_generate_vocab(
            data_dir,
            tmp_dir,
            self.target_vocab_name,
            self.approx_vocab_size,
            target_datasets,
            file_byte_budget=1e8)
        tag = "train" if train else "dev"
        filename_base = "wmt_enzh_%sk_sub_%s" % (self.approx_vocab_size, tag)
        data_path = translate.compile_data(tmp_dir, datasets, filename_base)
        return text_problems.text2text_generate_encoded(
            text_problems.text2text_txt_iterator(data_path + ".lang1",
                                                 data_path + ".lang2"),
            source_vocab, target_vocab)


    def feature_encoders(self, data_dir):
        source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
        target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
        source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
        target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
        return {
            "inputs": source_token,
            "targets": target_token,
        }