伯努利模型
BernoulliNB
介绍:
- 多项式朴素贝叶斯可同时处理二项分布(抛硬币)和多项式分布(掷骰子),其中二项式分布又叫做伯努利分布,它是一种现实中常见的,并且拥有很多优越数学性质的分布。因此,既然有着多项式朴素贝叶斯,我们自然也就又专门用来处理二项分布的朴素贝叶斯:伯努利朴素贝叶斯。
- 与多项式模型一样,伯努利模型适用于离散特征的情况,所不同的是,数据集中可以存在多个特征,但每个特征都是二分类的,伯努利模型中每个特征的取值只能是1和0(以文本分类为例,某个单词在文档中出现过,则其特征值为1,否则为0),伯努利模型需要比MultinomiaNB多定义一个二值化的方法,该方法会接受一个阈值并将输入的特征二值化(1, 0).当然也可以直接采用MulinomlialNB,但需要预先将输入的特征二值化
作用:
- 伯努利朴素贝叶斯与多项式朴素贝叶斯非常相似,都常用于处理文本分类数据,但由于伯努利朴素贝叶斯是处理二项分布,所以它更加在意的是“是”与“否”。判断一篇文章是否属于体育资讯,而不是说属于体育类还是娱乐类
API:
- class sklearn.naive_bayes.BernoulliNB(alpha=1.0, binarize=0.0, fit_prior=True, class_prior=None)
- 参数介绍:
- alpha:拉普拉斯平滑系数
- binarize:可以是数值或者不输入。如果不输入,则BernoulliNB认为每个数据特征都已经是二元(二值化)的。否则的话,小于binarize的会归为一类,大于binarize的回归为另外一类
- 二值化操作
大于阈值的为1, 小于等于的为0
from sklearn import preprocessing
import numpy as np
x = np.array([[1, -1, 4, 8, 3, 4, 5],
[3, 34, 5, 5, 6, 4, 4],
[3, 34, 5, 5, 6, 4, 4]])
binarizer = preprocessing.Binarizer(threshold=5)
X_binarizer = binarizer.transform(x)
print("二值化,(阈值:5)", X_binarizer)
二值化,(阈值:5) [[0 0 0 1 0 0 0]
[0 1 0 0 1 0 0]
[0 1 0 0 1 0 0]]
Process finished with exit code 0
- 应用
from sklearn import preprocessing
import numpy as np
from sklearn.naive_bayes import MultinomialNB, GaussianNB, BernoulliNB, ComplementNB
import sklearn.datasets as datasets
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
news = datasets.fetch_20newsgroups(subset='all')
feature = news.data # 返回的是列表,列表中为一篇篇文章
target = news.target # 返回的ndarray,储存的是每一篇文章的类别
x_train, x_test, y_train, y_test = train_test_split(feature, target, test_size=0.2)
# 生成文章特征词
tf = TfidfVectorizer() # 实例工具类
x_train = tf.fit_transform(x_train) # 返回训练集所有文章中每个词的重要性
x_test = tf.transform(x_test) # 返回测试集所有文章中每个词的重要性
print(x_train.shape)
print(x_test.shape)
# 使用模型进行文章分类
binarizer = preprocessing.Binarizer(threshold=5)
mlt = BernoulliNB()
mlt.fit(x_train, y_train)
y_predict = mlt.predict(x_test)
print('预测文章类别为:', y_predict)
print('真实文章类别为:', y_test)
print('准确率为:', mlt.score(x_test, y_test))
(15076, 147676)
(3770, 147676)
预测文章类别为: [11 7 6 ... 11 6 10]
真实文章类别为: [11 7 15 ... 11 17 10]
准确率为: 0.713262599469496
Process finished with exit code 0