Python量化投资:技术、模型与策略
上QQ阅读APP看书,第一时间看更新

4.2 seaborn

seaborn(官网:https://seaborn.pydata.org/)是一个很好用的统计图形库。它是基于Matplotlib开发的,能够很好地支持Pandas和NumPy的数据结构。

由于seaborn已经封装好了很多功能,因此在做统计相关的图形的操作时,会比Matplotlib更容易上手一些,也更实用一些。在进行可视化分析时,一般来说,笔者都会优先考虑使用seaborn。

本节将使用几个例子来说明seaborn的使用方法。这里主要是使用seaborn自带的示例数据。

首先是生成最常用的线性回归图,示例代码如下:


%matplotlib inline  
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

# 加载数据
iris = sns.load_dataset("iris")

# 绘图
g = sns.lmplot(x="sepal_length", y="sepal_width", hue="species",
               truncate=True, size=6, data=iris)

# 更改x,y轴的标签
g.set_axis_labels("Sepal length (mm)", "Sepal width (mm)")

运行结果如图4-5所示。

图 4-5

通过如下的命令来观察一下数据:


iris.head()

运行结果如图4-6所示。

图 4-6

可以看到,iris是一个DataFrame,在图4-6中,我们按照species分类,绘制了sepal_length和sepal_width的线性回归图,不同的类别是由不同的颜色来表现的。线性回归图包含了散点图和线性回归的拟合直线。

seaborn也可以与Matplotlib无缝对接。比如,在如下示例中,我们将三个直方图绘制在一起。最开始使用的是Matplotlib的子图功能,示例代码如下:


import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="white", context="talk")
rs = np.random.RandomState(7)


# 将图分为3*1的子图
f, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6), sharex=True)

# 生成数据
x = np.array(list("ABCDEFGHI"))
y1 = np.arange(1, 10)

# 绘制柱状图
sns.barplot(x, y1, palette="BuGn_d", ax=ax1)

# 更改y标签
ax1.set_ylabel("Sequential")

# 生成新数据,绘制第二个图
y2 = y1 - 5
sns.barplot(x, y2, palette="RdBu_r", ax=ax2)
ax2.set_ylabel("Diverging")

# 重新排列数据,绘制第三个图
y3 = rs.choice(y1, 9, replace=False)
sns.barplot(x, y3, palette="Set3", ax=ax3)
ax3.set_ylabel("Qualitative")

# 移除边框
sns.despine(bottom=True)

# 将y轴的tick设置为空(美化图形)
plt.setp(f.axes, yticks=[])

# 设置三个图的上下间隔
plt.tight_layout(h_pad=3)

运行结果如图4-7所示。

图 4-7

热力图也是一种常见统计图,热力图可以使用不同的颜色以及颜色的深浅来表示不同的数值,这样就可以方便直观地观察截面数据了。使用seaborn也可以很方便地绘制热力图。以下示例代码通过pivot生成了一个二维的flights数据,横坐标(column)是年度,纵坐标(index)是月份。输入flights数据画出对应的热力图,示例代码如下:


import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# 加载数据
flights_long = sns.load_dataset("flights")
flights = flights_long.pivot("month", "year", "passengers")

# 绘制热力图
f, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(flights, annot=True, fmt="d", linewidths=.5, ax=ax)

运行结果如图4-8所示。

图 4-8