Embedding理解
嵌入层将正整数(下标)转换为具有固定大小的向量 ------官网
词嵌入是一种语义空间到向量空间的映射,简单说就是把每个词语都转换为固定维数的向量,并且保证语义接近的两个词转化为向量后,这两个向量的相似度也高。
举例说明embedding过程:
“Could have done better”
- 通过索引对该句子进行编码,每个单词分配一个索引,上面的句子就会变成这样:
122 8 114 12
- 创建嵌入矩阵,即每一个索引需要分配多少维向量,也就是说每个词需要转化为多少维向量,这里设置为5,嵌入矩阵就会变成这样:
索引122对应的词向量:0.0190721 -0.04473796 0.03923314 0.04681129 -0.02183579
索引8对应的词向量:0.01421751 -0.00090249 0.01750712 -0.03774468 0.04996594
索引114对应的词向量:-0.04607415 0.04186441 0.02681447 -0.00218643 0.03448829
索引12对应的词向量:-0.03162882 0.03427991 0.0324514 0.03953638 0.01771886
下面引用苏神的一段话,我理解的是经常一起出现的词词向量数值会比较接近,也就是说经常一起共线的词语义是相似的,天呐,我是理解错了,还是这不在我的常识范围内呀,任重道远!!!
最后,解释一下为什么这些字词向量会有一些性质,比如向量的夹角余弦、向量的欧氏距离都能在一定程度上反应字词之间的相似性?这是因为,我们在用语言模型无监督训练时,是开了窗口的,通过前n个字预测下一个字的概率,这个n就是窗口的大小,同一个窗口内的词语,会有相似的更新,这些更新会累积,而具有相似模式的词语就会把这些相似更新累积到可观的程度。我举个例子,“忐”、“忑”这两个字,几乎是连在一起用的,更新“忐”的同时,几乎也会更新“忑”,因此它们的更新几乎都是相同的,这样“忐”、“忑”的字向量必然几乎是一样的。“相似的模式”指的是在特定的语言任务中,它们是可替换的,比如在一般的泛化语料中,“我喜欢你”中的“喜欢”,以及一般语境下的“喜欢”,替换为“讨厌”后还是一个成立的句子,因此“喜欢”与“讨厌”必然具有相似的词向量,但如果词向量是通过情感分类任务训练的,那么“喜欢”与“讨厌”就会有差异较大的词向量。-----苏神,来自参考资料2
Keras中embedding参数详解
from keras.layers.embeddings import Embedding
keras.layers.Embedding(input_dim, #词汇表大小,就是你的文本里你感兴趣词的数量
output_dim, #词向量的维度
embeddings_initializer='uniform',# Embedding矩阵的初始化方法
embeddings_regularizer=None,# Embedding matrix 的正则化方法
activity_regularizer=None,
embeddings_constraint=None, # Embedding matrix 的约束函数
mask_zero=False, #是否把 0 看作"padding" 值,取值为True时,接下来的所有层都必须支持 masking,词汇表的索引要从1开始(因为文档填充用的是0,如果词汇表索引从0开始会产生混淆,input_dim
=vocabulary + 1)
input_length=None)# 输入序列的长度,就是文档经过padding后的向量的长度。
'''
函数输入:尺寸为(batch_size, input_length)的2D张量,
batch_size就是你的mini batch里的样本量,
input_length就是你的文档转化成索引向量(每个词用词索引表示的向量)后的维数。
函数输出:尺寸为(batch_size, input_length,output_dim)的3D张量,
上面说了,output_dim就是词向量的维度,就是词转化为向量,这个向量的维度,
比如word2vec把“哈哈”转化为向量[1.01,2.9,3],那么output_dim就是3.
'''
代码案例说明
举例1:
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author:
@contact:
@time:
@context:
"""
from keras.layers.embeddings import Embedding
from keras.models import Sequential
import numpy as np
#我们随机生成第一层输入,即每个样本存储于单独的list,此list里的每个特征或者说元素用正整数索引表示,同时所有样本构成list
input_array = np.random.randint(1000, size=(32, 10))
'''
[[250 219 228 56 572 110 467 214 173 342]
[678 13 994 406 678 995 966 398 732 715]
...
[426 519 254 180 235 707 887 962 834 269]
[775 380 706 784 842 369 514 265 797 976]
[666 832 821 953 369 836 656 808 562 263]]
'''
model = Sequential()
model.add(Embedding(1000, 64, input_length=10))#词汇表里词999,词向量的维度64,输入序列的长度10
# keras.layers.Embedding(input_dim, output_dim, input_length)#词汇表大小,词向量的维度,输入序列的长度
print(model.input_shape)
print(model.output_shape)
'''
(None, 10) #其中 None的取值是batch_size
(None, 10, 64)
input_shape:函数输入,尺寸为(batch_size, 10)的2D张量(矩阵的意思)
output_shape:函数输出,尺寸为(batch_size, 10,64)的3D张量
'''
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
assert output_array.shape == (32, 10, 64)
print(output_array)
print(len(output_array))
print(len(output_array[1]))
print(len(output_array[1][1]))
'''
[
[[] [] [] ... [] [] []]
[[] [] [] ... [] [] []]
...
[[] [] [] ... [] [] []]
]
32:最外层维数32,32个样本
10:第二层维数10,每个样本用10个词表示
64:最内层维数64,每个词用64维向量表示
'''
举例2:
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author:
@contact:
@time:
@context:
"""
from keras.preprocessing.text import one_hot
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers.embeddings import Embedding
#我们定义一个文档集合存储于List,每个文档为list的一个元素,每个文档都对应一个标签,存储于labels
docs = ['Well done!',
'Good work',
'Great effort',
'nice work',
'Excellent!',
'Weak',
'Poor effort!',
'not good',
'poor work',
'Could have done better.']
labels = [1,1,1,1,1,0,0,0,0,0]
vocab_size = 50#估计的词汇表大小,设置时要比真实的词汇量大,不然会产生不同单词分配了相同的索引。
#通过索引对上面10个句子进行编码,one_hot编码映射到[1,vocab_size],不包括0
encoded_docs = [one_hot(d, vocab_size) for d in docs]
print(encoded_docs)
'''
[[3, 38], [20, 9], [36, 28], [15, 9], [21], [5], [45, 28], [1, 20], [45, 9], [31, 37, 38, 10]]
'''
# 文本编码成数字格式并padding到相同长度,这里长度设置为4,在后面补0,这也是为什么前面one-hot不会映射到0的原因。
max_length = 4
padded_docs = pad_sequences(encoded_docs, maxlen=max_length, padding='post')
print(padded_docs)
'''
[[ 3 38 0 0]
[20 9 0 0]
[36 28 0 0]
[15 9 0 0]
[21 0 0 0]
[ 5 0 0 0]
[45 28 0 0]
[ 1 20 0 0]
[45 9 0 0]
[31 37 38 10]]
'''
# define the model
model = Sequential()
model.add(Embedding(vocab_size, 8, input_length=max_length))
print(model.input_shape)
print(model.output_shape)
'''
(None, 4)
(None, 4, 8)
'''
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_docs)
assert output_array.shape == (10, 4, 8)
print(len(output_array))
print(len(output_array[1]))
print(len(output_array[1][1]))
print(output_array)
'''
10
4
8
[[[ 0.04572607 -0.03112372 0.01548124 0.0287031 0.03369636
-0.00907223 -0.02674365 0.0497326 ]
[ 0.02971635 0.01706659 0.01427769 0.02391822 0.02066484
0.03235774 0.00140371 -0.01571052]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.00637207 0.01458801 -0.02587212 0.0391363 0.04890009
-0.00473984 0.01941831 -0.03002635]
[ 0.02272599 0.01335565 -0.03088844 0.01404381 -0.00329325
0.0016606 0.00242132 -0.04546838]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[-0.02061187 -0.01111162 0.04552659 0.0447114 -0.02017692
0.04908471 0.00620199 0.04637216]
[ 0.04651392 -0.01801343 0.01927176 -0.03393314 -0.02526757
-0.00044692 0.01945822 0.01561001]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.04023353 -0.04503194 -0.01476847 0.04025214 -0.01467079
-0.04541937 0.00791662 0.04561491]
[ 0.02272599 0.01335565 -0.03088844 0.01404381 -0.00329325
0.0016606 0.00242132 -0.04546838]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.03823347 -0.01298066 -0.01494864 -0.00328387 -0.00303971
0.02827323 0.0077986 0.02893318]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[-0.00584649 -0.03266752 -0.043061 0.02855167 -0.0270277
0.01577503 -0.03172879 0.03462131]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.04572607 -0.03112372 0.01548124 0.0287031 0.03369636
-0.00907223 -0.02674365 0.0497326 ]
[ 0.04651392 -0.01801343 0.01927176 -0.03393314 -0.02526757
-0.00044692 0.01945822 0.01561001]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.04651392 -0.01801343 0.01927176 -0.03393314 -0.02526757
-0.00044692 0.01945822 0.01561001]
[ 0.00637207 0.01458801 -0.02587212 0.0391363 0.04890009
-0.00473984 0.01941831 -0.03002635]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.04572607 -0.03112372 0.01548124 0.0287031 0.03369636
-0.00907223 -0.02674365 0.0497326 ]
[ 0.02272599 0.01335565 -0.03088844 0.01404381 -0.00329325
0.0016606 0.00242132 -0.04546838]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]
[-0.00396045 -0.04596293 0.00576807 -0.0294588 -0.03388958
-0.0161563 -0.03131516 0.02193661]]
[[ 0.04632286 -0.03914303 -0.00696329 0.04238543 0.04322089
-0.02889879 0.0167807 0.03662675]
[ 0.03823347 -0.01298066 -0.01494864 -0.00328387 -0.00303971
0.02827323 0.0077986 0.02893318]
[ 0.02971635 0.01706659 0.01427769 0.02391822 0.02066484
0.03235774 0.00140371 -0.01571052]
[ 0.04477728 -0.02921386 0.03259372 -0.04354361 -0.02253401
0.04778937 0.03554988 0.01400479]]]
'''
参考资料