Abracadabra

Python for Data Analysis — learning notes, Chapter 8

绘图和可视化

1
2
3
4
5
6
7
8
9
10
from __future__ import division
from numpy.random import randn
import numpy as np
import os
import matplotlib.pyplot as plt
np.random.seed(12345)
plt.rc('figure', figsize=(10, 6))
from pandas import Series, DataFrame
import pandas as pd
np.set_printoptions(precision=4)
1
2
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
1
%matplotlib inline

matplotlib API 入门

1
import matplotlib.pyplot as plt

Figure 和 Subplot

1
fig = plt.figure()
1
ax1 = fig.add_subplot(2, 2, 1)
1
2
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
1
2
from numpy.random import randn
plt.plot(randn(50).cumsum(), 'k--')

png

1
2
_ = ax1.hist(randn(100), bins=20, color='k', alpha=0.3)
ax2.scatter(np.arange(30), np.arange(30) + 3 * randn(30))
1
plt.close('all')
1
2
fig, axes = plt.subplots(2, 3)
axes

png

调整subplot周围的间距

1
2
plt.subplots_adjust(left=None, bottom=None, right=None, top=None,
wspace=None, hspace=None)
1
2
3
4
5
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
# One histogram per panel; axes.flat walks the 2x2 grid row by row,
# the same order as the original nested i/j loops.
for ax in axes.flat:
    # Assign to _ so the notebook does not echo hist's (counts, bins, patches)
    _ = ax.hist(randn(500), bins=50, color='k', alpha=0.5)
# Remove the gaps between panels; sharex/sharey keep their scales aligned.
plt.subplots_adjust(wspace=0, hspace=0)
(array([  2.,   0.,   3.,   2.,   1.,   1.,   0.,   3.,   5.,   8.,   9.,
          9.,  10.,  18.,  34.,  13.,  24.,  30.,  24.,  24.,  25.,  20.,
         34.,  20.,  30.,  30.,  19.,  14.,  14.,   8.,  19.,  14.,   7.,
          3.,   7.,   2.,   7.,   2.,   2.,   0.,   1.,   0.,   0.,   0.,
          1.,   0.,   0.,   0.,   0.,   1.]),
 array([-2.9493, -2.8118, -2.6743, -2.5367, -2.3992, -2.2617, -2.1241,
        -1.9866, -1.849 , -1.7115, -1.574 , -1.4364, -1.2989, -1.1614,
        -1.0238, -0.8863, -0.7487, -0.6112, -0.4737, -0.3361, -0.1986,
        -0.0611,  0.0765,  0.214 ,  0.3516,  0.4891,  0.6266,  0.7642,
         0.9017,  1.0392,  1.1768,  1.3143,  1.4519,  1.5894,  1.7269,
         1.8645,  2.002 ,  2.1395,  2.2771,  2.4146,  2.5522,  2.6897,
         2.8272,  2.9648,  3.1023,  3.2398,  3.3774,  3.5149,  3.6525,
         3.79  ,  3.9275]),
 <a list of 50 Patch objects>)






(array([  1.,   1.,   0.,   2.,   0.,   1.,   1.,   5.,   7.,   4.,   5.,
          8.,  12.,  12.,  13.,  15.,  17.,  13.,  22.,  30.,  21.,  24.,
         17.,  20.,  20.,  20.,  18.,  26.,  16.,  24.,  19.,   8.,  14.,
         15.,   7.,  11.,   5.,   4.,   9.,   7.,   6.,   1.,   6.,   2.,
          4.,   2.,   0.,   2.,   1.,   2.]),
 array([-2.595 , -2.4898, -2.3845, -2.2793, -2.1741, -2.0688, -1.9636,
        -1.8584, -1.7531, -1.6479, -1.5427, -1.4374, -1.3322, -1.227 ,
        -1.1217, -1.0165, -0.9112, -0.806 , -0.7008, -0.5955, -0.4903,
        -0.3851, -0.2798, -0.1746, -0.0694,  0.0359,  0.1411,  0.2463,
         0.3516,  0.4568,  0.562 ,  0.6673,  0.7725,  0.8777,  0.983 ,
         1.0882,  1.1935,  1.2987,  1.4039,  1.5092,  1.6144,  1.7196,
         1.8249,  1.9301,  2.0353,  2.1406,  2.2458,  2.351 ,  2.4563,
         2.5615,  2.6667]),
 <a list of 50 Patch objects>)






(array([  1.,   0.,   1.,   0.,   0.,   1.,   0.,   1.,   1.,   0.,   4.,
          1.,   4.,   5.,  11.,   8.,   6.,  11.,  13.,  13.,  17.,  18.,
         20.,  27.,  32.,  29.,  31.,  22.,  21.,  31.,  29.,  19.,  22.,
         18.,  10.,  18.,  11.,  12.,   9.,   6.,   2.,   3.,   3.,   3.,
          2.,   1.,   1.,   1.,   0.,   1.]),
 array([-3.7454, -3.6052, -3.4651, -3.325 , -3.1849, -3.0448, -2.9047,
        -2.7646, -2.6244, -2.4843, -2.3442, -2.2041, -2.064 , -1.9239,
        -1.7837, -1.6436, -1.5035, -1.3634, -1.2233, -1.0832, -0.9431,
        -0.8029, -0.6628, -0.5227, -0.3826, -0.2425, -0.1024,  0.0377,
         0.1779,  0.318 ,  0.4581,  0.5982,  0.7383,  0.8784,  1.0185,
         1.1587,  1.2988,  1.4389,  1.579 ,  1.7191,  1.8592,  1.9994,
         2.1395,  2.2796,  2.4197,  2.5598,  2.6999,  2.84  ,  2.9802,
         3.1203,  3.2604]),
 <a list of 50 Patch objects>)






(array([  1.,   0.,   0.,   1.,   1.,   0.,   0.,   0.,   0.,   1.,   2.,
          5.,   9.,   8.,   6.,   2.,  11.,  17.,  10.,  13.,  10.,  14.,
         12.,  27.,  17.,  28.,  27.,  25.,  14.,  24.,  25.,  38.,  13.,
         24.,  15.,  10.,  17.,  14.,  13.,   8.,   7.,  10.,   3.,   7.,
          2.,   5.,   2.,   0.,   1.,   1.]),
 array([-3.4283, -3.3066, -3.185 , -3.0633, -2.9417, -2.8201, -2.6984,
        -2.5768, -2.4551, -2.3335, -2.2119, -2.0902, -1.9686, -1.847 ,
        -1.7253, -1.6037, -1.482 , -1.3604, -1.2388, -1.1171, -0.9955,
        -0.8739, -0.7522, -0.6306, -0.5089, -0.3873, -0.2657, -0.144 ,
        -0.0224,  0.0993,  0.2209,  0.3425,  0.4642,  0.5858,  0.7074,
         0.8291,  0.9507,  1.0724,  1.194 ,  1.3156,  1.4373,  1.5589,
         1.6806,  1.8022,  1.9238,  2.0455,  2.1671,  2.2887,  2.4104,
         2.532 ,  2.6537]),
 <a list of 50 Patch objects>)

png

颜色、标记和线型

1
plt.figure()
1
plt.plot(randn(30).cumsum(), 'ko--')

png

1
plt.close('all')
1
2
3
4
data = randn(30).cumsum()
# Plot the same random walk twice to compare the default line with a
# 'steps-post' staircase rendering.
plt.plot(data, 'k--', label='Default')
plt.plot(data, 'k-', label='steps-post', drawstyle='steps-post')
plt.legend(loc='best')

png

刻度、标签和图例

设置标题、轴标签、刻度以及刻度标签

1
2
3
4
5
6
7
8
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(randn(1000).cumsum())

# Fixed tick positions along x, each given a text label.
tick_positions = [0, 250, 500, 750, 1000]
ticks = ax.set_xticks(tick_positions)
labels = ax.set_xticklabels(['one', 'two', 'three', 'four', 'five'],
                            rotation=30, fontsize='small')
ax.set_title('My first matplotlib plot')
ax.set_xlabel('Stages')

png

添加图例

1
2
3
4
5
6
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Three independent random walks, each with its own line style and legend label.
for line_style, line_label in [('k', 'one'), ('k--', 'two'), ('k.', 'three')]:
    ax.plot(randn(1000).cumsum(), line_style, label=line_label)
ax.legend(loc='best')

png

注解以及在subplot上绘图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from datetime import datetime

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

data = pd.read_csv('ch08/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')

crisis_data = [
    (datetime(2007, 10, 11), 'Peak of bull market'),
    (datetime(2008, 3, 12), 'Bear Stearns Fails'),
    (datetime(2008, 9, 15), 'Lehman Bankruptcy'),
]

for date, label in crisis_data:
    # asof() gives the last available price at or before `date`; look it up once
    price = spx.asof(date)
    # Arrow tip sits 50 points above the price, label text 200 points above it.
    ax.annotate(label, xy=(date, price + 50),
                xytext=(date, price + 200),
                arrowprops=dict(facecolor='black'),
                horizontalalignment='left', verticalalignment='top')

# Zoom in on 2007-2010
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])
ax.set_title('Important dates in 2008-2009 financial crisis')

png

1
2
3
4
5
6
7
8
9
10
11
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Three semi-transparent shapes: rectangle ((x, y), width, height),
# circle ((x, y), radius) and a triangle given by its vertices.
rect = plt.Rectangle((0.2, 0.75), 0.4, 0.15, color='k', alpha=0.3)
circ = plt.Circle((0.7, 0.2), 0.15, color='b', alpha=0.3)
pgon = plt.Polygon([[0.15, 0.15], [0.35, 0.4], [0.2, 0.6]],
                   color='g', alpha=0.5)

for shape in (rect, circ, pgon):
    ax.add_patch(shape)

png

将图表保存到文件

1
fig

png

1
fig.savefig('figpath.svg')
1
fig.savefig('figpath.png', dpi=400, bbox_inches='tight')
1
2
3
4
from io import BytesIO

# Render the figure into an in-memory PNG buffer instead of a file.
# Save through the explicit ``fig``: plt.savefig() writes whatever figure is
# *current*, and the empty ``<matplotlib.figure.Figure ...>`` echoed by this
# cell suggests that was a brand-new blank figure, not ``fig``.
buffer = BytesIO()
fig.savefig(buffer)
plot_data = buffer.getvalue()
<matplotlib.figure.Figure at 0xaebe550>

matplotlib 配置

1
plt.rc('figure', figsize=(10, 10))

pandas中的绘图函数

线型图

1
plt.close('all')
1
2
s = Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()

png

1
2
3
4
# Four random walks (cumulative sums down each column), indexed 0, 10, ..., 90.
df = DataFrame(
    np.random.randn(10, 4).cumsum(axis=0),
    columns=['A', 'B', 'C', 'D'],
    index=np.arange(0, 100, 10),
)
df.plot()

png

柱状图

1
2
3
4
fig, axes = plt.subplots(2, 1)
data = Series(np.random.rand(16), index=list('abcdefghijklmnop'))
# Same Series as vertical bars (top panel) and horizontal bars (bottom panel).
for bar_kind, target_ax in zip(['bar', 'barh'], axes):
    data.plot(kind=bar_kind, ax=target_ax, color='k', alpha=0.7)

png

1
2
3
4
5
# Named column index: the name 'Genus' labels the column axis of the frame.
genus = pd.Index(['A', 'B', 'C', 'D'], name='Genus')
row_labels = ['one', 'two', 'three', 'four', 'five', 'six']
df = DataFrame(np.random.rand(6, 4), index=row_labels, columns=genus)
df
df.plot(kind='bar')



























































GenusABCD
one0.3016860.1563330.3719430.270731
two0.7505890.5255870.6894290.358974
three0.3815040.6677070.4737720.632528
four0.9424080.1801860.7082840.641783
five0.8402780.9095890.0100410.653207
six0.0628540.5898130.8113180.060217

png

1
plt.figure()
1
df.plot(kind='barh', stacked=True, alpha=0.5)

png

1
2
3
4
5
6
tips = pd.read_csv('ch08/tips.csv')
party_counts = pd.crosstab(tips.day, tips.size_)
party_counts
# Not many 1- and 6-person parties.
# The columns are integer party-size *labels*, so slice by label with .loc
# (inclusive: sizes 2, 3, 4, 5). .ix was removed from pandas, and
# .iloc[:, 2:5] would instead keep only the 3rd-5th columns (sizes 3-5).
party_counts = party_counts.loc[:, 2:5]
party_counts
































































size_123456
day
Fri1161100
Sat253181310
Sun039151831
Thur1484513





















































size_2345
day
Fri16110
Sat5318131
Sun3915183
Thur48451

1
2
3
4
5
# Normalize each row so the party-size shares for a day sum to 1
row_totals = party_counts.sum(axis=1).astype(float)
party_pcts = party_counts.div(row_totals, axis=0)
party_pcts
party_pcts.plot(kind='bar', stacked=True)




















































size_2345
day
Fri0.8888890.0555560.0555560.000000
Sat0.6235290.2117650.1529410.011765
Sun0.5200000.2000000.2400000.040000
Thur0.8275860.0689660.0862070.017241

png

直方图和密度图

1
plt.figure()
1
2
# Tip expressed as a fraction of the total bill.
tip_fraction = tips['tip'] / tips['total_bill']
tips['tip_pct'] = tip_fraction
tips['tip_pct'].hist(bins=50)

png

1
plt.figure()
1
tips['tip_pct'].plot(kind='kde')

png

1
plt.figure()
1
2
3
4
5
# Bimodal sample: 200 draws from N(0, 1) and 200 from N(10, 2**2).
comp1 = np.random.normal(0, 1, size=200)   # N(0, 1)
comp2 = np.random.normal(10, 2, size=200)  # N(10, 4): scale=2 is the std dev
values = Series(np.concatenate([comp1, comp2]))
# density=True puts the histogram on the same probability-density scale as
# the KDE overlay. It replaces the ``normed=True`` keyword, which matplotlib
# deprecated in 2.1 and later removed (density requires matplotlib >= 2.1).
values.hist(bins=100, alpha=0.3, color='k', density=True)
values.plot(kind='kde', style='k--')

png

散点图

1
2
3
4
macro = pd.read_csv('ch08/macrodata.csv')
data = macro[['cpi', 'm1', 'tbilrate', 'unemp']]
# Period-over-period change in log level; diff() leaves the first row NaN,
# and dropna() removes every row containing a NaN.
trans_data = np.log(data).diff().dropna()
trans_data.tail(5)




















































cpim1tbilrateunemp
198-0.0079040.045361-0.3968810.105361
199-0.0219790.066753-2.2772670.139762
2000.0023400.0102860.6061360.160343
2010.0084190.037461-0.2006710.127339
2020.0088940.012202-0.4054650.042560

1
plt.figure()
1
2
x_col, y_col = 'm1', 'unemp'
plt.scatter(trans_data[x_col], trans_data[y_col])
plt.title('Changes in log %s vs. log %s' % (x_col, y_col))

png

1
# Pairwise scatter plots with KDEs on the diagonal. The top-level
# pd.scatter_matrix was deprecated in pandas 0.20 and later removed; the
# function lives in pandas.plotting.
pd.plotting.scatter_matrix(trans_data, diagonal='kde', c='k', alpha=0.3)

png

绘制地图：图形化显示海地地震危机数据

1
2
data = pd.read_csv('ch08/Haiti.csv')
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3593 entries, 0 to 3592
Data columns (total 10 columns):
Serial            3593 non-null int64
INCIDENT TITLE    3593 non-null object
INCIDENT DATE     3593 non-null object
LOCATION          3592 non-null object
DESCRIPTION       3593 non-null object
CATEGORY          3587 non-null object
LATITUDE          3593 non-null float64
LONGITUDE         3593 non-null float64
APPROVED          3593 non-null object
VERIFIED          3593 non-null object
dtypes: float64(2), int64(1), object(7)
memory usage: 280.8+ KB
1
data[['INCIDENT DATE', 'LATITUDE', 'LONGITUDE']][:10]












































































INCIDENT DATELATITUDELONGITUDE
005/07/2010 17:2618.233333-72.533333
128/06/2010 23:0650.2260295.729886
224/06/2010 16:2122.278381114.174287
320/06/2010 21:5944.4070628.933989
418/05/2010 16:2618.571084-72.334671
526/04/2010 13:1418.593707-72.310079
626/04/2010 14:1918.482800-73.638800
726/04/2010 14:2718.415000-73.195000
815/03/2010 10:5818.517443-72.236841
915/03/2010 11:0018.547790-72.410010

1
data['CATEGORY'][:6]
0          1. Urgences | Emergency, 3. Public Health, 
1    1. Urgences | Emergency, 2. Urgences logistiqu...
2    2. Urgences logistiques | Vital Lines, 8. Autr...
3                            1. Urgences | Emergency, 
4                            1. Urgences | Emergency, 
5                       5e. Communication lines down, 
Name: CATEGORY, dtype: object
1
data.describe()
































































SerialLATITUDELONGITUDE
count3593.0000003593.0000003593.000000
mean2080.27748418.611495-72.322680
std1171.1003600.7385723.650776
min4.00000018.041313-74.452757
25%1074.00000018.524070-72.417500
50%2163.00000018.539269-72.335000
75%3088.00000018.561820-72.293570
max4052.00000050.226029114.174287

1
2
3
# Keep only reports inside a rough Haiti bounding box (describe() above shows
# stray points as far away as latitude 50 / longitude 114) that also have a
# non-null CATEGORY.
in_lat = (data.LATITUDE > 18) & (data.LATITUDE < 20)
in_lon = (data.LONGITUDE > -75) & (data.LONGITUDE < -70)
has_category = data.CATEGORY.notnull()
data = data[in_lat & in_lon & has_category]
1
2
3
4
5
6
7
8
9
10
11
12
13
def to_cat_list(catstr):
    """Split a comma-separated category string into its entries.

    Each entry is whitespace-stripped; empty entries (e.g. from a trailing
    comma) are dropped. Order is preserved.
    """
    entries = []
    for piece in catstr.split(','):
        piece = piece.strip()
        if piece:
            entries.append(piece)
    return entries
def get_all_categories(cat_series):
    """Return the sorted list of distinct category labels in *cat_series*.

    Each element of *cat_series* is a comma-separated category string that
    is split with ``to_cat_list``. An empty *cat_series* yields ``[]``.
    """
    # Accumulate into an explicit empty set: the previous
    # ``set.union(*generator)`` raised TypeError when the series was empty,
    # because set.union was then called with no set at all.
    all_cats = set()
    for catstr in cat_series:
        all_cats.update(to_cat_list(catstr))
    return sorted(all_cats)
def get_english(cat):
    """Split a category label into its code and English name.

    ``'2. Urgences logistiques | Vital Lines'`` -> ``('2', 'Vital Lines')``.
    The code is the text before the first '.'. If the remainder contains a
    '|', the segment after the first '|' is used as the English name;
    otherwise the whole remainder is. The name is whitespace-stripped.
    """
    # Split on the first '.' only, so names that themselves contain a
    # period no longer raise "too many values to unpack".
    code, names = cat.split('.', 1)
    if '|' in names:
        # Split on the bare '|' (not ' | ') so labels without spaces around
        # the bar don't raise IndexError; strip() below removes the padding.
        names = names.split('|')[1]
    return code, names.strip()
1
get_english('2. Urgences logistiques | Vital Lines')
('2', 'Vital Lines')
1
2
3
4
5
all_cats = get_all_categories(data.CATEGORY)
# Map each category code (e.g. '2a') to its English name
english_mapping = {code: name for code, name in map(get_english, all_cats)}
english_mapping['2a']
english_mapping['6c']
'Food Shortage'






'Earthquake and aftershocks'
1
2
3
4
5
6
7
def get_code(seq):
    """Return the code (text before the first '.') of each non-empty label.

    Empty/falsy labels in *seq* are skipped; order is preserved.
    """
    codes = []
    for label in seq:
        if not label:
            continue
        codes.append(label.split('.')[0])
    return codes
all_codes = get_code(all_cats)
code_index = pd.Index(np.unique(all_codes))  # sorted, de-duplicated codes
n_rows, n_cols = len(data), len(code_index)
# All-zero indicator matrix: one row per report, one column per code
dummy_frame = DataFrame(np.zeros((n_rows, n_cols)),
                        index=data.index, columns=code_index)
1
# Columns are string codes, so take the first six by position (.ix was removed)
dummy_frame.iloc[:, :6].info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3569 entries, 0 to 3592
Data columns (total 6 columns):
1     3569 non-null float64
1a    3569 non-null float64
1b    3569 non-null float64
1c    3569 non-null float64
1d    3569 non-null float64
2     3569 non-null float64
dtypes: float64(6)
memory usage: 195.2 KB
1
2
3
4
5
# Set the indicator to 1 for every code listed in each report's CATEGORY.
for row, cat in zip(data.index, data.CATEGORY):
    codes = get_code(to_cat_list(cat))
    # `row` is an index label (filtered labels are not contiguous), so use
    # label-based .loc; .ix was removed from pandas.
    dummy_frame.loc[row, codes] = 1
data = data.join(dummy_frame.add_prefix('category_'))
1
# The joined category_* columns start after the 10 original columns;
# select five of them by position (.ix was removed from pandas).
data.iloc[:, 10:15].info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3569 entries, 0 to 3592
Data columns (total 5 columns):
category_1     3569 non-null float64
category_1a    3569 non-null float64
category_1b    3569 non-null float64
category_1c    3569 non-null float64
category_1d    3569 non-null float64
dtypes: float64(5)
memory usage: 167.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
def basic_haiti_map(ax=None, lllat=17.25, urlat=20.25,
                    lllon=-75, urlon=-71, resolution='f'):
    """Create a Basemap of the Haiti region and draw its outlines.

    Parameters
    ----------
    ax : matplotlib Axes or None
        Axes to draw into; forwarded to ``Basemap(ax=...)``.
    lllat, urlat : float
        Latitudes of the lower-left and upper-right map corners.
    lllon, urlon : float
        Longitudes of the lower-left and upper-right map corners.
    resolution : str or None
        Boundary-data resolution forwarded to ``Basemap``. Defaults to 'f'
        (full), the value previously hard-coded; coarser levels such as
        'i' or 'l' are quicker to draw.

    Returns
    -------
    Basemap
        The map instance; calling ``m(lon, lat)`` converts longitudes and
        latitudes to map-projection x/y coordinates.
    """
    # Stereographic projection centred on the middle of the bounding box.
    # Divide by 2.0 so the centre is not floored under Python 2 integer
    # division when integer corner coordinates are passed.
    m = Basemap(ax=ax, projection='stere',
                lon_0=(urlon + lllon) / 2.0,
                lat_0=(urlat + lllat) / 2.0,
                llcrnrlat=lllat, urcrnrlat=urlat,
                llcrnrlon=lllon, urcrnrlon=urlon,
                resolution=resolution)
    # Draw coastlines, state and country boundaries.
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    return m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
fig.subplots_adjust(hspace=0.05, wspace=0.05)

# One category code per panel.
to_plot = ['2a', '1', '3c', '7a']

lllat = 17.25
urlat = 20.25
lllon = -75
urlon = -71

for code, ax in zip(to_plot, axes.flat):
    m = basic_haiti_map(ax, lllat=lllat, urlat=urlat,
                        lllon=lllon, urlon=urlon)
    # Reports whose indicator column for this code is set.
    cat_data = data[data['category_%s' % code] == 1]
    # compute map proj coordinates.
    x, y = m(cat_data.LONGITUDE.values, cat_data.LATITUDE.values)
    m.plot(x, y, 'k.', alpha=0.5)
    ax.set_title('%s: %s' % (code, english_mapping[code]))
C:\Users\Ewan\Anaconda3\envs\ipykernel_py2\lib\site-packages\mpl_toolkits\basemap\__init__.py:3260: MatplotlibDeprecationWarning: The ishold function was deprecated in version 2.0.
  b = ax.ishold()
C:\Users\Ewan\Anaconda3\envs\ipykernel_py2\lib\site-packages\mpl_toolkits\basemap\__init__.py:3269: MatplotlibDeprecationWarning: axes.hold is deprecated.
    See the API Changes document (http://matplotlib.org/api/api_changes.html)
    for more details.
  ax.hold(b)

png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Path to the street (road) shapefile, given without a file extension
shapefilepath = 'ch08/PortAuPrince_Roads/PortAuPrince_Roads'

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Zoom to a window of +/- `change` degrees around (lat0, lon0).
lat0 = 18.533333
lon0 = -72.333333
change = 0.13
lllat, urlat = lat0 - change, lat0 + change
lllon, urlon = lon0 - change, lon0 + change

m = basic_haiti_map(ax, lllat=lllat, urlat=urlat, lllon=lllon, urlon=urlon)
m.readshapefile(shapefilepath, 'roads')  # overlay the street data

code = '2a'
cat_data = data[data['category_%s' % code] == 1]
# compute map proj coordinates.
x, y = m(cat_data.LONGITUDE.values, cat_data.LATITUDE.values)
m.plot(x, y, 'k.', alpha=0.5)
ax.set_title('Food shortages reported in Port-au-Prince')
# plt.savefig('myfig.png',dpi=400,bbox_inches='tight')
(1583,
 3,
 [-72.749246, 18.409952, 0.0, 0.0],
 [-71.973789, 18.7147105, 0.0, 0.0],

png