这个转换的问题其实是贫僧在尝试将某个.txt
文件转换成Tensor来喂给训练好的神经网络模型时遇到的(训练的神经网络是char level的,具体看贫僧之前的博文)时遇到的。实现的步骤分成以下几个部分:
- 读取txt文件内容
- 将txt内容按照字典转化成对应的数字
- 将文件保存为.t7格式,方便神经网络读取
在正式开始之前先说下字典,字典是用这种方式生成的:
alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} "
dict = {}
for i = 1,#alphabet do
dict[alphabet:sub(i,i)] = i
end
alphabet_size = #alphabet
生成后的字典里面每一个字符都对应这一个数字。
读取txt文件内容
这部分的实现很简单,因为Lua已经实现了这个功能,直接调用就行了:
--[[
函数名:read_files
输入:txt文件路径
输出:txt文件内容,string类型
]]
function read_files(filename)
local f = assert(io.open(filename, 'r'))
local content = f:read('*all')
f:close()
return content
end
上面这个函数最后会返回txt文件的内容,因为比较简单所以就不细讲了。
将txt内容按照字典转化成对应的数字
这个地方有两个难点:
- Lua不像Python可以直接用for来遍历字符串,所以要用另一种方法来遍历字符串
- 按照字典转换
先上代码:
--[[
函数名:translate_char
输入:文件名,是否保存(true或者false)
输出:按照字典翻译之后的tensor
备注:其实可以改一改就可以变成word型的,默认是只接受一个文件,然后重复60次内容,如果需要接收60个不同文件的内容要另外改
]]
function translate_char(filename, save_translation)
local m = torch.Tensor(60, 201, 10):zero()
local content = read_files(filename)
local tmp_i = 0
local j = 1
for i = 1, #content do
if content:sub(i, i) == '\n' then -- 分行
j = j + 1
tmp_i = i
else
if j > 10 then -- 为了保证后面输入进矩阵的时候不会越界,而且也有助于看哪个文档里面的caption多于10个
print('error! j is %d, file name is %s', j, filename)
break
end
if dict[content:sub(i, i)] ~= nil then -- 避免文档中出现了字典中没有的特殊字符
if i - tmp_i < 201 then
for k = 1, 60 do
m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]
end
-- print(j) -- 这是用来调试的代码,用来检查行数是不是正确的
end
end
end
end
if save_translation then
torch.save(string.gsub(filename, '.txt', '.t7'), m)
end
return m
end
上面多了一些次要的东西,例如local m = torch.Tensor(60, 201, 10):zero()
,这行是设置零矩阵,这是根据论文的要求来做的,因为默认内容长度不够的地方会用0来补充,而内容长度超过201的时候才会忽略掉后面的内容(所以有这句if i - tmp_i < 201 then
)。核心的替换代码其实是m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]
。而遍历字符串用到的语句就是content:sub(i, i)
,要搭配for循环来用。
先细讲一下遍历字符串,sub(j, k)
这个函数其实就是截取第j
到k
位的字符(注意,Lua里面字符从1
开始计数),所以如果是sub(i, i)
的话就会提取第i
位的字符,因此Lua里面遍历字符串要这样做:
for i = 1, #content do
character = content:sub(i, i)
-- 补充对单个字母的操作
end
而查字典的话其实就是将字母输入到字典的索引里面(其实Lua就只有table一个类型,但是贫僧更加熟悉python,所以用了“字典”这个词)。
m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]
结合上面的遍历字符串操作就可以很容易理解这里是在做什么了。