4. 朴素贝叶斯分类新闻文本
原创
©著作权归作者所有:来自51CTO博客作者避风塘主的原创作品,请联系作者获取转载授权,否则将追究法律责任
4. 朴素贝叶斯分类新闻文本
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# 加载数据
newsgroups_train = fetch_20newsgroups(subset='train', categories=['alt.atheism', 'sci.space'])
newsgroups_test = fetch_20newsgroups(subset='test', categories=['alt.atheism', 'sci.space'])
# 文本特征提取
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(newsgroups_train.data)
X_test = vectorizer.transform(newsgroups_test.data)
y_train, y_test = newsgroups_train.target, newsgroups_test.target
# 创建并训练模型
model = MultinomialNB()
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred))
# 可视化混淆矩阵
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()