前言
最近在学习机器挖掘内容,其中有一个问题应该是大家都会碰到的问题,就是如果样本数据中类别样本个数相差巨大该如何处理,比如,A类别100个样本,B类别10000个样本。这样类别差别训练模型实际效果并不理想。所以需要一个方法来解决这个问题。
技巧
欠采样(undersampling)和过采样(oversampling)会对模型带来怎样的影响?这篇文章讲解了很多东西,其实大家可以看看,其中也有用模型来实际训练模型来得到一些结果,可以看出来Oversampling方法和 SMOTE方法都非常有效,对实验提升较大。作者总结了一个方法:使用过采样(或 smote)+ XGBOOST 比较适合不平衡的数据,可以将这个方法作为一个Baseline方法。
应用场景
其实很多场景里会出现这个问题,比如:垃圾邮件识别,微博垃圾用户,虚假信息识别等,这些都是反例样本很少、正例样本多的场景,这里都可以应用本文的技巧来处理数据。
代码实测
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
import pandas as pd
import numpy as np
## 定义的一个随机森林的方法,参数是训练集,利用训练集来训练,并且用来做验证
def RandomForestTest(x_train,y_train):
RF = RandomForestClassifier(n_jobs=3)
# 训练样本
RF.fit(x_train, y_train)
y_pred_train = RF.predict(x_train)
f1 = f1_score(y_train,y_pred_train)
print("train F1 : {:.3f}".format(f1))
### 读取数据集
dataPath = "datalab/37451/"
dataFile = "creditcard.csv"
data = pd.read_csv(dataPath + dataFile , encoding='utf-8')
### 看一下 class 的值信息
print(data['Class'].value_counts())
### 负采样 方法,
def dataProcess(data):
# Shuffle the Dataset.
shuffled_df = data.sample(frac=1,random_state=4)
# Put all the fraud class in a separate dataset.
fraud_df = shuffled_df.loc[shuffled_df['Class'] == 1]
#Randomly select 492 observations from the non-fraud (majority class)
non_fraud_df=shuffled_df.loc[shuffled_df['Class']== 0].sample(n=492,random_state=42)
# Concatenate both dataframes again
normalized_df = pd.concat([fraud_df, non_fraud_df])
print(normalized_df.shape)
return normalized_df
normalized_df = dataProcess(data)
normalized_df.head(2)
### 不同的 数据集到方法中,看下拟合效率
x = data.drop('Class', axis=1)
y = data['Class']
RandomForestTest(x,y)
print(" - - - - 负采样后 - - - - ")
x_nor = normalized_df.drop('Class', axis=1)
y_nor = normalized_df['Class']
RandomForestTest(x_nor,y_nor)
结果:
0 284315
1 492
Name: Class, dtype: int64
train F1 : 0.969
- - - - 负采样后 - - - -
train F1 : 0.995
不过这个后续还要多补充一些,比如 上采样的实验,以及测试集的指标来说明情况,因为目前都是用训练集来做,第二个训练集的数据还特别少。
完整的代码在我的github账号里,欢迎大家来看看,附上链接:inbalanceDataTest.ipynb
参考博客
欠采样(undersampling)和过采样(oversampling)会对模型带来怎样的影响?