数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

  • A+
所属分类:机器学习
摘要这一篇介绍一下关于样本不平衡的处理的方式,主要介绍两种采样方式,分别是上采样和下采样。这里主要介绍最简单的上采样和下采样,更多的内容见文章中的链接。

简介

这一部分讲一下样本平衡的一些做法。所有内容来自下面的链接。

下面这个参考资料很好,十分建议查看Resampling strategies for imbalanced datasets

为什么要做样本平衡

如果正负样本差别很大,或是类别与类别之间相差很大,那么模型就会偏向于预测最常出现的样本。同时,这样做最后可以获得较高的准确率,但是这个准确率不能说明模型有多好。

In a dataset with highly unbalanced classes, if the classifier always "predicts" the most common class without performing any analysis of the features, it will still have a high accuracy rate, obviously illusory.

解决办法

解决样本不平衡的问题,有两个大的方向是可以解决的。一个是under-sampling,另一个是over-sampling。(A widely adopted technique for dealing with highly unbalanced datasets is called resampling. It consists of removing samples from the majority class (under-sampling) and / or adding more examples from the minority class (over-sampling).)

Under-sampling

under-sampling我们可以理解为将较多的分类中的样本中取一些出来,使得较多的分类的数量与较少分类的数量相同。(这里采样的方式会有很多)

Over-sampling

所谓over-sampling,我们可以理解为将少的一部分样本进行重采样,使其变多。(这里重采样的方式会有很多)

下面这张图片概括了under-sampling和over-sampling两者区别。

数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

当然,使用上面两种方式是会有代价的,如果使用undersampling,会出现丢失信息的问题。如果使用oversampling的方式,会出现过拟合的问题。

Despite the advantage of balancing classes, these techniques also have their weaknesses (there is no free lunch). The simplest implementation of over-sampling is to duplicate random records from the minority class, which can cause overfitting. In under-sampling, the simplest technique involves removing random records from the majority class, which can cause loss of information.

简单实验

下面我们使用NSL-KDD数据集来做一下简单的实验。我们在这里只实现简单的over-sampling和under-sampling,关于一些别的采样方式可以参考上面的链接,我在这里再放一下。

数据集准备

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt

下面导入数据集

  1. COL_NAMES = ["duration", "protocol_type", "service", "flag", "src_bytes",
  2.              "dst_bytes", "land", "wrong_fragment", "urgent", "hot", "num_failed_logins",
  3.              "logged_in", "num_compromised", "root_shell", "su_attempted", "num_root",
  4.              "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
  5.              "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate",
  6.              "srv_serror_rate", "rerror_rate", "srv_rerror_rate", "same_srv_rate",
  7.              "diff_srv_rate", "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count",
  8.              "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
  9.              "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate",
  10.              "dst_host_rerror_rate", "dst_host_srv_rerror_rate", "labels"]
  11. # 导入数据集
  12. Trainfilepath = './NSL-KDD/KDDTrain+.txt'
  13. dfDataTrain = pd.read_csv(Trainfilepath, names=COL_NAMES, index_col=False)

我们简单查看一下各类攻击的分布。

  1. target_count = dfDataTrain.labels.value_counts()
  2. target_count.plot(kind='barh', title='Count (target)');
数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

在这里,我们只对尝试其中的四种攻击,分别是back,neptune,smurf,teardrop。我们简单看一下这四种攻击的分布。

  1. DataBack = dfDataTrain[dfDataTrain['labels']=='back']
  2. DataNeptune = dfDataTrain[dfDataTrain['labels']=='neptune']
  3. DataSmurf = dfDataTrain[dfDataTrain['labels']=='smurf']
  4. DataTeardrop = dfDataTrain[dfDataTrain['labels']=='teardrop']
  5. DataAll = pd.concat([DataBack, DataNeptune, DataSmurf, DataTeardrop], axis=0, ignore_index=True).sample(frac=1) # 合并成为新的数据
  6. # 查看各类的分布
  7. target_count = DataAll.labels.value_counts()
  8. target_count.plot(kind='barh', title='Count (target)');
数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

Over-Sampling

我们使用简单的过采样,即重复取值,使其样本个数增多。

  1. from imblearn.over_sampling import RandomOverSampler
  2. # 实现简单过采样
  3. ros = RandomOverSampler()
  4. X = DataAll.iloc[:,:41].to_numpy()
  5. y = DataAll['labels'].to_numpy()
  6. X_ros, y_ros = ros.fit_sample(X, y)
  7. print(X_ros.shape[0] - X.shape[0], 'new random picked points')
  8. # 组成pandas的格式
  9. DataAll = pd.DataFrame(X_ros, columns=COL_NAMES[:-1])
  10. DataAll['labels'] = y_ros
  11. # 进行可视化展示
  12. target_count = DataAll.labels.value_counts()
  13. target_count.plot(kind='barh', title='Count (target)');

简单看一下最终的结果,可以看到每个类别的样本现在都是40000+,相当于都和之前最多的样本的个数是相同的。

数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

Under-Sampling

下面简单实现一下下采样,也是直接去掉比较多的类中的数据。

  1. from imblearn.under_sampling import RandomUnderSampler
  2. rus = RandomUnderSampler(return_indices=True)
  3. X = DataAll.iloc[:,:41].to_numpy()
  4. y = DataAll['labels'].to_numpy()
  5. X_rus, y_rus, id_rus = rus.fit_sample(X, y)
  6. # 组成pandas的格式
  7. DataAll = pd.DataFrame(X_rus, columns=COL_NAMES[:-1])
  8. DataAll['labels'] = y_rus
  9. # 进行绘图
  10. target_count = DataAll.labels.value_counts()
  11. target_count.plot(kind='barh', title='Count (target)');

可以看到现在每个样本的个数都是800+,这样就完成了under-sampling.

数据样本不平衡时处理方法(Resampling strategies for imbalanced datasets)

这里只是简单的介绍关于上采样和下采样的方式,还有一些其他的采样方式可以参考上面的链接。

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: