循序渐进,学会用pyecharts绘制桑基图

桑基图介绍

桑基图是比较冷门的可视化图形,知道的人不多,但它的可视化效果很惊艳,以后肯定会有越来越多的人使用,我平时使用桑基图,主要是用其绘制可视化图形做PPT。

桑基图(Sankey diagram),也称为桑基能量分流图、桑基能量平衡图,因1898年Sankey绘制的“蒸汽机的能源效率图”而闻名,此后便以其名命名为“桑基图”。

桑基图是一种特殊类型的流程图,展示了数据的流动关系和演变关系,始端宽度的总和与末端宽度的总和相等,保持数据的平衡。

桑基图通常应用于能源、材料成分、金融等数据的可视化分析,随着接触的人越来越多,正逐步被应用于更多领域的可视化分析中。

用Python绘制桑基图并不算难,用过两三次就熟悉绘图套路了,接下来结合具体的实例,本文将一步一步介绍如何用Python绘制桑基图。

数据预处理

假设有某“路边社”采访了1000名打工人回家过年的交通方式,数据汇总如下。

循序渐进,学会用pyecharts绘制桑基图_Python绘制桑基图

从原始数据看,受访者有性别和交通工具类型的信息,如果要可视化不同性别的人里分别有多少人使用了哪种交通方式,体现出流动的效果,最适合的就是使用桑基图。

要绘制桑基图,需要先分析数据中有哪些节点,以及节点之间的流量关系,如性别中有两个节点:男、女,交通工具中有高铁、飞机、汽车、骑马四个节点,节点之间的流量关系,如男生中有320个人乘坐高铁,则流量关系为从“男”节点到“高铁”节点的流量为320,要提取这些信息,需要先对数据进行预处理。

第一步,先调整表格的数据,把性别和交通工具信息都放到表格数据里,而不像原统计数据里是作为表格的列名和行名。用pandas处理成如下数据:

import pandas as pd

df = pd.DataFrame({
    "性别": ['男', '男', '男', '男', '女', '女', '女', '女'],
    "交通工具": ['高铁', '飞机', '汽车', '骑马', '高铁', '飞机', '汽车', '骑马'],
    "人数": [320, 47, 190, 5, 296, 82, 60, 0]
})
print(df)

Output:

性别 交通工具 人数
0  男   高铁  320
1  男   飞机   47
2  男   汽车  190
3  男   骑马    5
4  女   高铁  296
5  女   飞机   82
6  女   汽车   60
7  女   骑马    0

第二步,获取数据中所有的节点,从df中提取出数据中第一列和第二列的唯一值,并处理成字典的格式。(pyecharts绘制桑基图时节点使用字典格式)

nodes = []
for i in range(2):
    data = df.iloc[:, i].unique()
    for d in data:
        node = dict()
        node['name'] = d
        nodes.append(node)
print(nodes)

Output:

[{'name': '男'}, {'name': '女'}, {'name': '高铁'}, {'name': '飞机'}, {'name': '汽车'}, {'name': '骑马'}]

第三步,获取所有节点之间的数据流量关系,并处理成字典格式。对于数据流量为0的情况(如数据中骑马的女生为0),可以保留也可以不保留,我的处理方式是不保留。

linkes = []
for i in df.values:
    lin = dict()
    lin['source'] = i[0]
    lin['target'] = i[1]
    lin['value'] = i[2]
    if lin['value'] > 0:
        linkes.append(lin)
print(linkes)

Output:

[{'source': '男', 'target': '高铁', 'value': 320}, {'source': '男', 'target': '飞机', 'value': 47}, {'source': '男', 'target': '汽车', 'value': 190}, {'source': '男', 'target': '骑马', 'value': 5}, {'source': '女', 'target': '高铁', 'value': 296}, {'source': '女', 'target': '飞机', 'value': 82}, {'source': '女', 'target': '汽车', 'value': 60}]

绘制桑基图

Python绘制桑基图使用pyecharts是最方便的,直接调用pyecharts中的Sankey模块即可。

from pyecharts.charts import Sankey
from pyecharts import options as opts

san = Sankey(init_opts=opts.InitOpts(width='800px', height='400px'))
san.add(
    '',  # 图例名称
    nodes,  # 传入节点数据
    linkes,  # 传入节点间的流量数据
    linestyle_opt=opts.LineStyleOpts(opacity=0.3, curve=0.5, color='source'),  # 设置透明度、弯曲度、颜色
    label_opts=opts.LabelOpts(position='right'),  # 设置标签显示位置等样式
    node_gap=10,  # 设置节点之间的间距
).set_global_opts(title_opts=opts.TitleOpts(title='春节回家交通方式', pos_left='35%'))
san.render('sankey1.html')

循序渐进,学会用pyecharts绘制桑基图_sankey_02

代码说明:

pyecharts绘制桑基图直接使用Sankey模块,安装pyecharts并导入即可调用。先实例化一个Sankey模块的对象,然后调用add()方法添加数据和设置各种样式。

add()方法的第一个参数是图例名称,可以直接为空,这里不用图例也行。第二个参数和第三个参数传入预处理时的节点数据和节点间的流量数据。linestyle_opt参数用于设置图形的透明度、弯曲度、颜色等,label_opts参数用于设置标签的显示位置,node_gap参数用于设置节点间的间距。

set_global_opts中通过title_opts参数设置标题的样式和位置等,有很多样式设置是pyecharts中对多种图形通用的设置,可以点进源码看每一项设置的用途,也可以到pyecharts的官方文档里查看,结合自己的修改尝试,基本都能明白每个参数用来做什么。

绘制多层桑基图

对于更复杂的数据,可能需要三层及以上的桑基图来做可视化。例如上面的数据除了统计大家的交通工具,还统计了交通工具的座次等信息,如下:

循序渐进,学会用pyecharts绘制桑基图_Python可视化_03

对于这份数据,可以用三层的桑基图来可视化,还是一样,先对数据预处理,调整表格将数据信息放进数据中。

df2 = pd.DataFrame({
    "性别": ['男', '男', '男', '男', '男', '男', '男', '男',
            '女', '女', '女', '女', '女', '女', '女', '女'],
    "交通工具": ['高铁', '高铁', '高铁', '飞机', '飞机', '汽车', '汽车', '骑马',
            '高铁', '高铁', '高铁', '飞机', '飞机', '汽车', '汽车', '骑马'],
    "座次": ['商务座', '一等座', '二等座', '头等舱', '经济舱', '大巴', '自驾', '真骑马',
            '商务座', '一等座', '二等座', '头等舱', '经济舱', '大巴', '自驾', '真骑马'],
    "人数": [5, 15, 300, 2, 45, 80, 110, 5, 16, 50, 230, 7, 75, 55, 5, 0]
})
print(df2)

Output:

性别 交通工具   座次   人数
0   男   高铁  商务座    5
1   男   高铁  一等座   15
2   男   高铁  二等座  300
3   男   飞机  头等舱    2
4   男   飞机  经济舱   45
5   男   汽车   大巴   80
6   男   汽车   自驾  110
7   男   骑马  真骑马    5
8   女   高铁  商务座   16
9   女   高铁  一等座   50
10  女   高铁  二等座  230
11  女   飞机  头等舱    7
12  女   飞机  经济舱   75
13  女   汽车   大巴   55
14  女   汽车   自驾    5
15  女   骑马  真骑马    0

获取节点时取三列的数据。

nodes = []
for i in range(3):
    data = df2.iloc[:, i].unique()
    for d in data:
        node = dict()
        node['name'] = d
        nodes.append(node)
print(nodes)

Output:

[{'name': '男'}, {'name': '女'}, {'name': '高铁'}, {'name': '飞机'}, {'name': '汽车'}, {'name': '骑马'}, {'name': '商务座'}, {'name': '一等座'}, {'name': '二等座'}, {'name': '头等舱'}, {'name': '经济舱'}, {'name': '大巴'}, {'name': '自驾'}, {'name': '真骑马'}]

由于节点数据有三列,而节点间的数据流量关系是一个source对应一个target,为了方便处理,先对df2中的数据做一次分组聚合处理。

first = df2.groupby(['性别', '交通工具'])['人数'].sum().reset_index()
second = df2.groupby(['交通工具', '座次'])['人数'].sum().reset_index()
first.columns = ['source', 'target', 'value']
second.columns = ['source', 'target', 'value']
result = pd.concat([first, second])
print(result)

Output:

0      女     汽车     60
1      女     飞机     82
2      女     骑马      0
3      女     高铁    296
4      男     汽车    190
5      男     飞机     47
6      男     骑马      5
7      男     高铁    320
0     汽车     大巴    135
1     汽车     自驾    115
2     飞机    头等舱      9
3     飞机    经济舱    120
4     骑马    真骑马      5
5     高铁    一等座     65
6     高铁    二等座    530
7     高铁    商务座     21

然后用相同的方式得到节点间的流量关系。

linkes = []
for i in result.values:
    lin = dict()
    lin['source'] = i[0]
    lin['target'] = i[1]
    lin['value'] = i[2]
    if lin['value'] > 0:
        linkes.append(lin)
print(linkes)

Output:

[{'source': '女', 'target': '汽车', 'value': 60}, {'source': '女', 'target': '飞机', 'value': 82}, {'source': '女', 'target': '高铁', 'value': 296}, {'source': '男', 'target': '汽车', 'value': 190}, {'source': '男', 'target': '飞机', 'value': 47}, {'source': '男', 'target': '骑马', 'value': 5}, {'source': '男', 'target': '高铁', 'value': 320}, {'source': '汽车', 'target': '大巴', 'value': 135}, {'source': '汽车', 'target': '自驾', 'value': 115}, {'source': '飞机', 'target': '头等舱', 'value': 9}, {'source': '飞机', 'target': '经济舱', 'value': 120}, {'source': '骑马', 'target': '真骑马', 'value': 5}, {'source': '高铁', 'target': '一等座', 'value': 65}, {'source': '高铁', 'target': '二等座', 'value': 530}, {'source': '高铁', 'target': '商务座', 'value': 21}]

节点和节点间的流量关系准备好后,用相同的代码即可绘制出三层的桑基图。

san2 = Sankey()
san2.add(
    '',  # 图例名称
    nodes,  # 传入节点数据
    linkes,  # 传入节点间的流量数据
    linestyle_opt=opts.LineStyleOpts(opacity=0.3, curve=0.5, color='source'),  # 设置透明度、弯曲度、颜色
    label_opts=opts.LabelOpts(position='right'),  # 设置标签显示位置等样式
    node_gap=10,  # 设置节点之间的间距
).set_global_opts(title_opts=opts.TitleOpts(title=''))
san2.render('sankey2.html')

循序渐进,学会用pyecharts绘制桑基图_pyecharts_04

可以看出,两层桑基图和三层桑基图的绘图代码相同,区别在于如何对数据做预处理,得到不同的 nodes(节点)和 linkes(节点间的流量关系),这也正是绘制桑基图的关键点。

因此,对于更复杂的桑基图,根本原因是数据间的关系复杂了,将数据之间的关系处理好传给Sankey模块,它就可以完成可视化。

这里有一个小注意点,在绘制多层桑基图时,节点名不能重复,例如上面数据中的“骑马”在第二层向第三层流动时,数据没有发生变化,如果在第三层的节点也叫“骑马”,节点名重复了,图形不能正常生成。

功能进阶

绘制桑基图的方法和关键已经介绍完成,不过它的功能远不止这些,还可以有很多的变化,如将方向变成垂直方向、自定义节点颜色、自定义显示内容等。本文再介绍一个用JsCode自定义显示内容的方法,其他的内容就不逐一介绍了。

例如上面的图形中,每个节点都只显示了节点的名称,没有显示节点的数值,下面通过JsCode自定义将数值也显示出来。

第一步,用分组聚合的方式得到每一个节点的数值。(这里再次用到了分组聚合,可以参考:Pandas知识点-详解分组函数groupby)

gender = df2.groupby('性别')['人数'].sum().reset_index()
transport = df2.groupby('交通工具')['人数'].sum().reset_index().sort_values('人数', ascending=False)
seat = df2.groupby('座次')['人数'].sum().reset_index()
gender.columns = ['name', 'value']
transport.columns = ['name', 'value']
seat.columns = ['name', 'value']
result = pd.concat([gender, transport, seat])
print(result)

Output:

0    女    438
1    男    562
3   高铁    616
0   汽车    250
1   飞机    129
2   骑马      5
0  一等座     65
1  二等座    530
2  商务座     21
3   大巴    135
4  头等舱      9
5  真骑马      5
6  经济舱    120
7   自驾    115

第二步,将节点的数值添加到节点的字典中。

nodes = []
for i in result.values:
    node = dict()
    node['name'] = i[0]
    node['value'] = i[1]
    nodes.append(node)
print(nodes)

Output:

[{'name': '女', 'value': 438}, {'name': '男', 'value': 562}, {'name': '高铁', 'value': 616}, {'name': '汽车', 'value': 250}, {'name': '飞机', 'value': 129}, {'name': '骑马', 'value': 5}, {'name': '一等座', 'value': 65}, {'name': '二等座', 'value': 530}, {'name': '商务座', 'value': 21}, {'name': '大巴', 'value': 135}, {'name': '头等舱', 'value': 9}, {'name': '真骑马', 'value': 5}, {'name': '经济舱', 'value': 120}, {'name': '自驾', 'value': 115}]

第三步,在可视化时,导入pyecharts.commons.utils中的JsCode模块,通过自定义函数的方式设置标签的显示格式,例如要同时显示节点名和节点的值,则在自定义函数中将两者拼接返回(参考下方JsCode中的代码)。

from pyecharts.commons.utils import JsCode

san3 = Sankey()
san3.add(
    '',  # 图例名称
    nodes,  # 传入节点数据
    linkes,  # 传入节点间的流量数据
    linestyle_opt=opts.LineStyleOpts(opacity=0.35, curve=0.5, color='source'),  # 设置透明度、弯曲度、颜色
    label_opts=opts.LabelOpts(position='right',
                              formatter=JsCode(
                                  """function(params){return params.name+':'+params.value}"""
                              )),  # 设置标签显示位置等样式
    node_gap=10,  # 设置节点之间的间距
).set_global_opts(title_opts=opts.TitleOpts(title=''))
san3.render('sankey3.html')

循序渐进,学会用pyecharts绘制桑基图_Python绘制桑基图_05

最后,有可能桑基图的节点排列顺序不符合你的预期,在add()方法中有一个参数 is_draggable ,这个参数的默认值为True,节点默认是可以用鼠标随意拖拽的,可视化的节点位置不满意,可以拖拽进行微调。

参考文档:
[1] pyecharts官方文档:https://gallery.pyecharts.org/#/Bar/README