一、简介
在机器学习中,经常需要通过散点图查看原始数据的分布情况,从而对特征和算法的选择进行初步判断。
散点图可以形象展示直角坐标系中两个变量之间的关系。在散点图中 ,每个数据点的位置实际上就是两个变量的值。变量间的任何关系都可以拿散点图来表示。
matplotlib绘图功能模仿MATLAB,非常方便和强大。下面,本文将详细介绍如何使用matplotlib画出好看实用的散点图。
如果你对matplotlib完全不熟悉,可以先花10分钟去我的另一篇博客学习一下基本操作:
10分钟带你从零上手matplotlib数据可视化
需要进一步深入了解的朋友可以查看 matplotlib.pyplot.scatter 官方文档
二、2D散点图参数及实例
1. 常用参数详解
import matplotlib.pyplot as plt
plt.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None,
vmin=None, vmax=None, alpha=None, linewidths=None,
verts=None, edgecolors=None, *, data=None, **kwargs)
- x,y传入数组,形如
shape(n,)
,表示每个点的横、纵坐标 - s传入标量或数组,形如
shape(n,)
,表示每个点标记的大小,可选,默认为[‘lines.markersize’] ** 2 - c传入颜色,或颜色序列,表示点标记的颜色,b=蓝色,g=绿色,r=红色,y=黄色,k=黑色,w=白色,c=蓝绿色,m=洋红
- marker,传入字符串,表示点标记的样式,默认为’o’,常用的有
'^', '*', 'o','+','x'
,
更多样式可以查看官方文档 - linewidths传入标量或数组,形如
shape(n,)
,表示标记的边框线宽,默认为None - edgecolors传入颜色或颜色序列,表示标记边框的颜色,默认为’face’,传入‘face’表示与标记颜色相同,传入’none’表示无边框。
2. 最简单的2D散点图实例
import matplotlib.pyplot as plt
x = [1, 2, 3, 4]
y = [1, 2, 3, 4]
plt.scatter(x, y, s=[10, 20, 50, 100], c=['r', 'y', 'g', 'b'])
plt.show()
从图中可以看出来,的确是s控制了每个点的大小,c控制了每个点的颜色。
3. 机器学习中的2D散点图实例
下面我们先用sklearn中经典的iris分类数据画一个二维散点图
from sklearn import datasets
import matplotlib.pyplot as plt
#从sklearn中获取经典的iris数据
iris = datasets.load_iris() #iris.data为150x4矩阵
x1 = iris.data[:, 1] #获取第二列特征值
x2 = iris.data[:, 2] #获取第三列特征值
y = iris.target #y是分类值:0,1,2
plt.scatter(x1, x2, c=y) #将y作为参数传给c能够很方便的区分不同类别的颜色
plt.title('Iris Classification')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.show()
三、3D散点图参数及实例
Matplotlib 绘制 3D 图像主要通过 mpl_toolkits.mplot3d 模块实现,但由于三维图像实际上是在二维画布上展示,因此同样需要载入 pyplot 模块。
备注:mpl_toolkits.mplot3d这个模块不需要另外安装,matplotlib中已自带。
1. 常用参数详解
Axes3D.scatter(xs, ys, zs=0, zdir='z', s=20, c=None,
depthshade=True, *args, **kwargs)
- xs, ys控制x轴和y轴坐标
- zs控制z轴坐标,默认为0,如果传入1个标量,那么就是所有点在同一高度,传入数组就是与xs, ys一一对应的高度。
- s控制点的大小
- c控制点的颜色
2. 3D散点图实例
还是用上面的数据,我们取iris.data中前三个特征画散点图
from sklearn import datasets
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#从sklearn中获取经典的iris数据
iris = datasets.load_iris() #iris.data为150x4矩阵
x1 = iris.data[:, 0] #获取第1列特征值
x2 = iris.data[:, 1] #获取第2列特征值
x3 = iris.data[:, 2] #获取第3列特征值
y = iris.target #y是分类值:0,1,2
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x1, x2, x3, c=y)
plt.show()