
Python Data Analysis - Learning Notes - Ch02

Posted on 2017-02-19

Use Python's built-in json module to parse the data and turn each record into a dictionary.

The data looks like this:

'{ "a": "Mozilla\\/5.0 (Windows NT 6.1; WOW64) AppleWebKit\\/535.11 (KHTML, like Gecko)
Chrome\\/17.0.963.78 Safari\\/535.11", "c": "US", "nk": 1, "tz": "America\\/New_York", "gr":
"MA", "g": "A6qOVH", "h": "wfLQtf", "l": "orofrog", "al": "en-US,en;q=0.8", "hh": "1.usa.gov",
"r": "http:\\/\\/www.facebook.com\\/l\\/7AQEFzjSi\\/1.usa.gov\\/wfLQtf", "u":
"http:\\/\\/www.ncbi.nlm.nih.gov\\/pubmed\\/22415991", "t": 1331923247, "hc": 1331822918,
"cy": "Danvers", "ll": [ 42.576698, -70.954903 ] }\n'

Core code:

import json
path = 'ch02/usagov_bitly_data2012-03-16-1331923249.txt'
records = [json.loads(line) for line in open(path)]
records[0]
-----------------------------------
{'a': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/535.11 (KHTML, like Gecko) Chrome/17.0.963.78 Safari/535.11',
'al': 'en-US,en;q=0.8',
'c': 'US',
'cy': 'Danvers',
'g': 'A6qOVH',
'gr': 'MA',
'h': 'wfLQtf',
'hc': 1331822918,
'hh': '1.usa.gov',
'l': 'orofrog',
'll': [42.576698, -70.954903],
'nk': 1,
'r': 'http://www.facebook.com/l/7AQEFzjSi/1.usa.gov/wfLQtf',
't': 1331923247,
'tz': 'America/New_York',
'u': 'http://www.ncbi.nlm.nih.gov/pubmed/22415991'}

Counting the time-zone field (pure Python vs. pandas)

First, extract the time-zone field from each record into a list (not every record contains a 'tz' field, hence the check below):

time_zones = [rec['tz'] for rec in records if 'tz' in rec]
time_zones[:10]
-----------------------------------
['America/New_York',
'America/Denver',
'America/New_York',
'America/Sao_Paulo',
'America/New_York',
'America/New_York',
'Europe/Warsaw',
'',
'',
'']

Counting with pure Python

def get_counts(sequence):
    counts = {}
    for x in sequence:
        if x in counts:
            counts[x] += 1
        else:
            counts[x] = 1
    return counts

The following version is more concise:

from collections import defaultdict
def get_counts2(sequence):
    counts = defaultdict(int)  # values will initialize to 0
    for x in sequence:
        counts[x] += 1
    return counts

If we want the top ten time zones and their counts:

def top_counts(count_dict, n=10):
    value_key_pairs = [(count, tz) for tz, count in count_dict.items()]
    value_key_pairs.sort()
    return value_key_pairs[-n:]
top_counts(counts)
--------------------------------------
[(33, 'America/Sao_Paulo'),
(35, 'Europe/Madrid'),
(36, 'Pacific/Honolulu'),
(37, 'Asia/Tokyo'),
(74, 'Europe/London'),
(191, 'America/Denver'),
(382, 'America/Los_Angeles'),
(400, 'America/Chicago'),
(521, ''),
(1251, 'America/New_York')]

Alternatively, we can use the standard library's collections.Counter:

from collections import Counter
counts = Counter(time_zones)
counts.most_common(10)
--------------------------------
[('America/New_York', 1251),
('', 521),
('America/Chicago', 400),
('America/Los_Angeles', 382),
('America/Denver', 191),
('Europe/London', 74),
('Asia/Tokyo', 37),
('Pacific/Honolulu', 36),
('Europe/Madrid', 35),
('America/Sao_Paulo', 33)]

Doing the same task with pandas

The main data structure in pandas is the DataFrame, which represents data as a table.

from pandas import DataFrame, Series
import pandas as pd
frame = DataFrame(records)
frame

dataframe_data_repr

frame['tz'][:10]
-------------------------------
0 America/New_York
1 America/Denver
2 America/New_York
3 America/Sao_Paulo
4 America/New_York
5 America/New_York
6 Europe/Warsaw
7
8
9
Name: tz, dtype: object

Count the values:

tz_counts = frame['tz'].value_counts()
tz_counts[:10]
--------------------------------------------------
America/New_York 1251
521
America/Chicago 400
America/Los_Angeles 382
America/Denver 191
Europe/London 74
Asia/Tokyo 37
Pacific/Honolulu 36
Europe/Madrid 35
America/Sao_Paulo 33
Name: tz, dtype: int64

Filling in missing and unknown values

clean_tz = frame['tz'].fillna('Missing')
clean_tz[clean_tz == ''] = 'Unknown'
tz_counts = clean_tz.value_counts()
tz_counts[:10]
----------------------------------------------
America/New_York 1251
Unknown 521
America/Chicago 400
America/Los_Angeles 382
America/Denver 191
Missing 120
Europe/London 74
Asia/Tokyo 37
Pacific/Honolulu 36
Europe/Madrid 35
Name: tz, dtype: int64

Plot it (assuming matplotlib.pyplot has been imported as plt):

plt.figure(figsize=(10, 4))
tz_counts[:10].plot(kind='barh', rot=0)

majority_tz

Next, let's work with the browser (user-agent) information.

A Series is essentially a single column of a DataFrame.

results = Series([x.split()[0] for x in frame.a.dropna()])
results[:5]
---------------------------------------------------
0 Mozilla/5.0
1 GoogleMaps/RochesterNY
2 Mozilla/4.0
3 Mozilla/5.0
4 Mozilla/5.0
dtype: object

These can be counted in the same way:

results.value_counts()[:8]
-----------------------------------------
Mozilla/5.0 2594
Mozilla/4.0 601
GoogleMaps/RochesterNY 121
Opera/9.80 34
TEST_INTERNET_AGENT 24
GoogleProducer 21
Mozilla/6.0 5
BlackBerry8520/5.0.0.681 4
dtype: int64

Group the time zones by Windows vs. non-Windows users

cframe = frame[frame.a.notnull()]
import numpy as np
operating_system = np.where(cframe['a'].str.contains('Windows'),
                            'Windows', 'Not Windows')
operating_system[:5]
-----------------------------------------------------------------
array(['Windows', 'Not Windows', 'Windows', 'Not Windows', 'Windows'],
dtype='<U11')
by_tz_os = cframe.groupby(['tz', operating_system])

Let's see what by_tz_os looks like:

by_tz_os.size()

pandas_group_by_data_pic

Now look at the neat effect of unstack():

pandas_group_by_data_unstack
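The variable agg_counts used in the next step only appears in the screenshots; presumably it is the unstacked table of group sizes, something like the following (an assumption, not shown in the post's text):

agg_counts = by_tz_os.size().unstack().fillna(0)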

Sort it and see the ranking:

# Use to sort in ascending order
indexer = agg_counts.sum(1).argsort()
indexer[:10]
------------------------------------------------
tz
24
Africa/Cairo 20
Africa/Casablanca 21
Africa/Ceuta 92
Africa/Johannesburg 87
Africa/Lusaka 53
America/Anchorage 54
America/Argentina/Buenos_Aires 57
America/Argentina/Cordoba 26
America/Argentina/Mendoza 55
dtype: int64

Take the top ten and have a look:

count_subset = agg_counts.take(indexer)[-10:]
count_subset

pandas_group_by_data_sort_top10

Plot it as well:

count_subset.plot(kind='barh', stacked=True)

pandas_group_by_data_sort_top10_pic

Normalize to see what share each of the two categories accounts for:

normed_subset = count_subset.div(count_subset.sum(1), axis=0)
normed_subset.plot(kind='barh', stacked=True)

pandas_group_by_data_sort_top10_percent

Joining the MovieLens rating tables

import pandas as pd
import os
encoding = 'latin1'
upath = os.path.expanduser('ch02/movielens/users.dat')
rpath = os.path.expanduser('ch02/movielens/ratings.dat')
mpath = os.path.expanduser('ch02/movielens/movies.dat')
unames = ['user_id', 'gender', 'age', 'occupation', 'zip']
rnames = ['user_id', 'movie_id', 'rating', 'timestamp']
mnames = ['movie_id', 'title', 'genres']
users = pd.read_csv(upath, sep='::', header=None, names=unames, encoding=encoding)
ratings = pd.read_csv(rpath, sep='::', header=None, names=rnames, encoding=encoding)
movies = pd.read_csv(mpath, sep='::', header=None, names=mnames, encoding=encoding)

See what the data look like:

users[:5]

pandas_ch02_users

ratings[:5]

pandas_ch02_ratings

movies[:5]

pandas_ch02_movies

Multi-table join

data = pd.merge(pd.merge(ratings, users), movies)
data

pandas_ch02_multi_table_joint

data.ix[0]
--------------------------------------------
user_id 1
movie_id 1193
rating 5
timestamp 978300760
gender F
age 1
occupation 10
zip 48067
title One Flew Over the Cuckoo's Nest (1975)
genres Drama
Name: 0, dtype: object

Compute each movie's mean rating by gender:

mean_ratings = data.pivot_table('rating', index='title',
                                columns='gender', aggfunc='mean')
mean_ratings[:5]

pandas_ch02_movie_avg_score_by_gender

Filter out movies with fewer than 250 ratings:

ratings_by_title = data.groupby('title').size()
ratings_by_title[:5]
-----------------------------------------
title
$1,000,000 Duck (1971) 37
'Night Mother (1986) 70
'Til There Was You (1997) 52
'burbs, The (1989) 303
...And Justice for All (1979) 199
dtype: int64
active_titles = ratings_by_title.index[ratings_by_title >= 250]
active_titles[:10]
-----------------------------------------
Index([''burbs, The (1989)', '10 Things I Hate About You (1999)',
'101 Dalmatians (1961)', '101 Dalmatians (1996)', '12 Angry Men (1957)',
'13th Warrior, The (1999)', '2 Days in the Valley (1996)',
'20,000 Leagues Under the Sea (1954)', '2001: A Space Odyssey (1968)',
'2010 (1984)'],
dtype='object', name='title')

Here .ix selects rows by label, restricting mean_ratings to the titles in active_titles (effectively an intersection with the index); in current pandas, .loc is used for this.

mean_ratings = mean_ratings.ix[active_titles]
mean_ratings

pandas_ch02_movie_avg_score_by_gender_ratings_more_than_250
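Since .ix has been removed from newer pandas versions, the same label-based selection can be written as a small sketch with .loc:

mean_ratings = mean_ratings.loc[active_titles]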

Sort in descending order of the female ratings to get the movies women rated highest:

top_female_ratings = mean_ratings.sort_values(by='F', ascending=False)
top_female_ratings[:10]

pandas_ch02_movie_female_favor_top_10

​

US Baby Names 1880-2010

import pandas as pd
names1880 = pd.read_csv('ch02/names/yob1880.txt', names=['name', 'sex', 'births'])
names1880

pandas_ch02_us_baby_name

Combine the data for all years:

# 2010 is the last available year right now
years = range(1880, 2011)
pieces = []
columns = ['name', 'sex', 'births']
for year in years:
    path = 'ch02/names/yob%d.txt' % year
    frame = pd.read_csv(path, names=columns)
    frame['year'] = year
    pieces.append(frame)
# Concatenate everything into a single DataFrame
names = pd.concat(pieces, ignore_index=True)

Aggregate:

total_births = names.pivot_table('births', index='year',
                                 columns='sex', aggfunc=sum)
total_births.tail()

pandas_ch02_us_baby_name_addition

Compute each name's proportion of births:

def add_prop(group):
    # Integer division floors
    births = group.births.astype(float)
    group['prop'] = births / births.sum()
    return group
names = names.groupby(['year', 'sex']).apply(add_prop)
names

pandas_ch02_us_baby_name_prop

Run a sanity check (the proportions within each group should sum to 1):

np.allclose(names.groupby(['year', 'sex']).prop.sum(), 1)
--------------------------------------------
True

Select the top 1,000 names within each year/sex group:

def get_top1000(group):
    return group.sort_values(by='births', ascending=False)[:1000]
grouped = names.groupby(['year', 'sex'])
top1000 = grouped.apply(get_top1000)

Give it a fresh integer index (using NumPy's arange):

top1000.index = np.arange(len(top1000))

Analyzing naming trends

Split the data into boys and girls:

boys = top1000[top1000.sex == 'M']
girls = top1000[top1000.sex == 'F']

Compute the total births for each name in each year:

total_births = top1000.pivot_table('births', index='year', columns='name',
                                   aggfunc=sum)
total_births

pandas_ch02_us_baby_name_counts_per_year

Pick a few names and see how their totals change over the years:

subset = total_births[['John', 'Harry', 'Mary', 'Marilyn']]
subset.plot(subplots=True, figsize=(12, 10), grid=False,
            title="Number of births per year")

pandas_ch02_us_baby_name_trend

Measuring the increase in naming diversity

Measure the change in diversity by the share of births accounted for by the top 1,000 names:

table = top1000.pivot_table('prop', index='year',
                            columns='sex', aggfunc=sum)
table.plot(title='Sum of table1000.prop by year and sex',
           yticks=np.linspace(0, 1.2, 13), xticks=range(1880, 2020, 10))

pandas_ch02_us_baby_name_diversity

Another approach: count how many names account for 50% of the births.

That is, accumulate proportions from the top and see how many names are needed to reach a 50% share.

Start with boys in 2010:

df = boys[boys.year == 2010]
prop_cumsum = df.sort_index(by='prop', ascending=False).prop.cumsum()
prop_cumsum[:10]
--------------------------------------------------
260877 0.011523
260878 0.020934
260879 0.029959
260880 0.038930
260881 0.047817
260882 0.056579
260883 0.065155
260884 0.073414
260885 0.081528
260886 0.089621
Name: prop, dtype: float64

The answer turns out to be position 116; since positions start at 0, that means 117 names.

prop_cumsum.values.searchsorted(0.5)
---------------------------------------------------
116

Now look at boys in 1900:

df = boys[boys.year == 1900]
in1900 = df.sort_index(by='prop', ascending=False).prop.cumsum()
in1900.values.searchsorted(0.5) + 1
---------------------------------------------------
25

So this approach works.

Now apply the same operation to the entire dataset:

def get_quantile_count(group, q=0.5):
    group = group.sort_values(by='prop', ascending=False)
    return group.prop.cumsum().values.searchsorted(q) + 1
diversity = top1000.groupby(['year', 'sex']).apply(get_quantile_count)
diversity = diversity.unstack('sex')
diversity.head()

pandas_ch02_us_baby_name_number_in_half_percent

diversity.plot(title="Number of popular names in top 50%")

pandas_ch02_us_baby_name_diversity_2

The “Last letter” Revolution

Extract the last letter of each name, keeping the index aligned:

# extract last letter from name column
get_last_letter = lambda x: x[-1]
last_letters = names.name.map(get_last_letter)
last_letters.name = 'last_letter'
table = names.pivot_table('births', index=last_letters,
                          columns=['sex', 'year'], aggfunc=sum)

Pull out three representative years:

subtable = table.reindex(columns=[1910, 1960, 2010], level='year')
subtable.head()

pandas_ch02_us_baby_name_last_letter

Compute the letter proportions:

subtable.sum()
-------------------------------------
sex year
F 1910 396416.0
1960 2022062.0
2010 1759010.0
M 1910 194198.0
1960 2132588.0
2010 1898382.0
dtype: float64
letter_prop = subtable / subtable.sum().astype(float)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
letter_prop['M'].plot(kind='bar', rot=0, ax=axes[0], title='Male')
letter_prop['F'].plot(kind='bar', rot=0, ax=axes[1], title='Female',
                      legend=False)

pandas_ch02_us_baby_name_last_letter_prop

Finally, look at all years and plot the trend:

letter_prop = table / table.sum().astype(float)
dny_ts = letter_prop.ix[['d', 'n', 'y'], 'M'].T
dny_ts.plot()

pandas_ch02_us_baby_name_last_letter_prop_trend

Boy names that became girl names (and vice versa)

Take names containing 'lesl' as an example:

all_names = top1000.name.unique()
mask = np.array(['lesl' in x.lower() for x in all_names])
lesley_like = all_names[mask]
lesley_like
----------------------------------------------
array(['Leslie', 'Lesley', 'Leslee', 'Lesli', 'Lesly'], dtype=object)

Filter them out of the original dataset:

filtered = top1000[top1000.name.isin(lesley_like)]
filtered.groupby('name').births.sum()
----------------------------------------------
name
Leslee 1082
Lesley 35022
Lesli 929
Leslie 370429
Lesly 10067
Name: births, dtype: int64

Aggregate and compute the proportions by sex:

table = filtered.pivot_table('births', index='year',
                             columns='sex', aggfunc='sum')
table = table.div(table.sum(1), axis=0)
table.tail()

pandas_ch02_us_baby_name_b2g_prop

Plot the trend:

table.plot(style={'M': 'k-', 'F': 'k--'})

pandas_ch02_us_baby_name_b2g_prop_trend


Spider the house information and save it to an Excel file

Posted on 2017-02-18

Data source

http://sh.fang.com/

Project goal

Crawl the residential-community information from the second-hand housing listings.

Implementation steps

[1] Crawl the community information (core code; the same applies to the steps below)

""" <A spider to crawl the house information.>
Copyright (C) <2017> Li W.H., Duan X
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
class HouseSpider(scrapy.Spider):
name = "house"
head = "http://esf.sh.fang.com"
allowed_domains = ["sh.fang.com"]
start_urls = [
"http://esf.sh.fang.com/housing/"
]
# 各区对应的编号(由于丧心病狂的url)
area_map = {25: 'pudong', 18: 'minhang', 19: 'xuhui', 30: 'baoshan',
28: 'putuo', 20: 'changning', 26: 'yangpu', 586: 'songjiang',
29: 'jiading', 23: 'hongkou', 27: 'zhabei', 21: 'jingan',
24: 'huangpu', 22: 'luwan', 31: 'qingpu', 32: 'fengxian',
35: 'jinshan', 996: 'chongming'}
estate_to_area_map = {}
seperator = '=\n'
def __init__(self):
for key, value in self.area_map.items():
self.estate_to_area_map[key] = []
# print(self.estate_to_area_map)
def parse(self, response):
# 解析出上海市各区的地址
area_lis = response.xpath('//*[@id="houselist_B03_02"]/div[1]')
for a in area_lis.xpath('./a'):
# areas = items.AreaItem()
# areas['name'] = a.xpath('text()').extract()[0]
yield Request(self.head + a.xpath('@href').extract()[0],
callback=self.parse_area)
# print(a.xpath('text()').extract()[0])
def parse_area(self, response):
# 确定response来源于哪一个区
area_index = str(response).split('/')[-2].split('_')[0]
if area_index == '':
return
else:
# 解析出各区中小区的详情页面地址
detail_str = 'xiangqing'
estate_list = response.xpath('/html/body/div[4]/div[5]/div[4]')
for a in estate_list.xpath('.//a[@class="plotTit"]'):
estate_url = a.xpath('@href').extract()[0]
if estate_url.find('esf') != -1:
estate_url = estate_url.replace('esf', detail_str)
else:
estate_url = estate_url + detail_str
if estate_url.find('http') != -1:
# print(estate_url)
self.estate_to_area_map[int(area_index)].append(estate_url)
# print(len(self.estate_to_area_map[int(area_index)]))
next_page = response.xpath('//*[@id="PageControl1_hlk_next"]')
if len(next_page) != 0:
yield Request(self.head +
next_page.xpath('@href').extract()[0],
callback=self.parse_area)
else:
# print(len(self.estate_to_area_map[int(area_index)]))
for url in self.estate_to_area_map[int(area_index)]:
request = Request(url, callback=self.parse_house,
dont_filter=True)
request.meta['index'] = int(area_index)
yield request
def parse_house(self, response):
flag = 0
area_index = response.meta['index']
area_name = self.area_map[area_index]
filename = area_name + '.txt'
# print(response.xpath('/html'))
# 详情页面存在两种,因此分情况讨论
house_name = response.xpath(
'/html/body/div[4]/div[2]/div[2]/h1/a/text()')
if len(house_name) == 0:
# house_name = response.xpath(
# '/html/body/div[1]/div[3]/div[2]/h1/a/text()')
# flag = 1
return
house_name = house_name.extract()[0]
# 清洁小区名
house_name = re.sub(r'小区网', '', house_name)
result_str = '【小区名称】' + house_name + '\n'
if flag == 0:
avg_price_xpath = response.xpath(
'/html/body/div[4]/div[4]/div[1]/div[1]/dl[1]/dd/span/text()')
avg_price = avg_price_xpath.extract()[0]
result_str = result_str + '【平均价格】' + avg_price + '\n'
detail_block_list = response.xpath(
'/html/body/div[4]/div[4]/div[1]')
for headline in detail_block_list.xpath('.//h3'):
head_str = headline.xpath('./text()').extract()[0]
if head_str == '基本信息':
result_str = result_str + \
'【' + \
head_str + '】\n'
for item in headline.xpath(
'../../div[@class="inforwrap clearfix"]/dl/dd'):
if len(item.xpath('./strong/text()')) != 0:
if len(item.xpath('./text()')) != 0:
result_str = result_str + \
item.xpath(
'./strong/text()').extract()[0]
result_str = result_str + \
item.xpath('./text()').extract()[0] + '\n'
# print(result_str)
# elif head_str == '交通状况':
# result_str = result_str + \
# '【' + \
# head_str + '】\n'
# tempstr = headline.xpath(
# '../../div[@class="inforwrap clearfix"]/dl/dt/text()').extract()[0]
# result_str = result_str + tempstr + '\n'
# # print(result_str)
# elif head_str == '周边信息':
# result_str = result_str + \
# '【' + \
# head_str + '】\n'
# for item in headline.xpath(
# '../../div[@class="inforwrap clearfix"]/dl/dt'):
# result_str = result_str + \
# item.xpath('./text()').extract()[0] + '\n'
# # print(result_str)
elif head_str == '就近楼群':
result_str = result_str + \
'【' + \
head_str + '】\n'
for item in headline.xpath(
'../../div[@class="inforwrap clearfix"]/dl/dd'):
result_str = result_str + \
item.xpath('./a/text()').extract()[0] + '\n'
result_str = result_str + self.seperator
# print(result_str)
with open(filename, 'a', errors='ignore') as f:
f.write(result_str)

[2] Format the raw text

""" <A formatter>
Copyright (C) <2017> Li W.H., Duan X
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
def GetDataFileList(path='.'):
""" Get the houses file list.
Arguments:
path: Dir path.
Returns:
file_list: the list of data file that find houses data.
"""
file_list = [x for x in os.listdir(path) if os.path.isfile(
x) and os.path.splitext(x)[1] == '.txt']
return file_list
def Parse(file_list):
""" Parse the txt file that find houses data.
Extract some import infomation such as house name,
avarage price, address and so on.
Arguments:
file_list: the list of data file that find houses data.
Returns:
houses_dict_list: the list that each item find the detail
dict of each house.
"""
HOUSE_NAME = '小区名称'
HOUSE_NAME_SPLITOR = '】'
HOUSE_ADDRESS = '小区地址'
HOUSE_ADDRESS_SPLITOR = ':'
HOUSE_AVG_PRICE = '平均价格'
HOUSE_AVG_PRICE_SPLITOR = '】'
AREA_OF_HOUSE_BELONGS_TO = '所属区域'
AREA_OF_HOUSE_BELONGS_TO_SPLITOR_1 = ':'
AREA_OF_HOUSE_BELONGS_TO_SPLITOR_2 = ' '
PROPERTY_CATEGORY = '物业类别'
PROPERTY_CATEGORY_SPLITOR = ':'
GREEN_RATE = '绿 化 率'
GREEN_RATE_SPLITOR = ':'
VOLUME_RATE = '容 积 率'
VOLUME_RATE_SPLITOR = ':'
PROPERTY_COSTS = '物 业 费'
PROPERTY_COSTS_SPLITOR = ':'
NO_INFO_NOW = '暂无信息'
DETAIL_LIST = [HOUSE_NAME, HOUSE_AVG_PRICE, HOUSE_ADDRESS, AREA_OF_HOUSE_BELONGS_TO,
PROPERTY_CATEGORY, GREEN_RATE, VOLUME_RATE, PROPERTY_COSTS]
houses_dict_list = []
for file_name in file_list:
raw_houses_string = ''
# read all lines as a string
with open(file_name, 'r', errors='ignore') as f:
for line in f.readlines():
raw_houses_string += line
# split the string to the houses raw info list
raw_houses_list = raw_houses_string.split('=\n')
raw_houses_details_list = []
for raw_house in raw_houses_list:
# format house raw info to lines
raw_houses_details = raw_house.split('\n')[:-1]
if len(raw_houses_details) == 0:
continue
# combine the all formated house raw info to a list
raw_houses_details_list.append(raw_houses_details)
for raw_house_details in raw_houses_details_list:
house_details_dict = {}
for raw_detail in raw_house_details:
# search house name
if re.search(HOUSE_NAME, raw_detail):
house_details_dict[HOUSE_NAME] = raw_detail.split(
HOUSE_NAME_SPLITOR)[-1]
# search house avarage price
elif re.search(HOUSE_AVG_PRICE, raw_detail):
# print(raw_detail)
house_details_dict[HOUSE_AVG_PRICE] = raw_detail.split(
HOUSE_AVG_PRICE_SPLITOR)[-1]
# search house address
elif re.search(HOUSE_ADDRESS, raw_detail):
house_details_dict[HOUSE_ADDRESS] = raw_detail.split(
HOUSE_ADDRESS_SPLITOR)[-1]
# search the area of house belongs to
elif re.search(AREA_OF_HOUSE_BELONGS_TO, raw_detail):
temp_detail_value = raw_detail.split(
AREA_OF_HOUSE_BELONGS_TO_SPLITOR_1)[-1]
detail_value = temp_detail_value.split(
AREA_OF_HOUSE_BELONGS_TO_SPLITOR_2)[0]
house_details_dict[AREA_OF_HOUSE_BELONGS_TO] = detail_value
# search the property category of house
elif re.search(PROPERTY_CATEGORY, raw_detail):
house_details_dict[PROPERTY_CATEGORY] = raw_detail.split(
PROPERTY_CATEGORY_SPLITOR)[-1]
# search the green rate
elif re.search(GREEN_RATE, raw_detail):
house_details_dict[GREEN_RATE] = raw_detail.split(
GREEN_RATE_SPLITOR)[-1]
# search the volume rate
elif re.search(VOLUME_RATE, raw_detail):
house_details_dict[VOLUME_RATE] = raw_detail.split(
VOLUME_RATE_SPLITOR)[-1]
# search the property costs
elif re.search(PROPERTY_COSTS, raw_detail):
house_details_dict[PROPERTY_COSTS] = raw_detail.split(
PROPERTY_COSTS_SPLITOR)[-1]
# Judge if all details are contained.
# If not, set to null.
house_details_dict_keys = house_details_dict.keys()
for detail_name in DETAIL_LIST:
if detail_name not in house_details_dict_keys:
house_details_dict[detail_name] = NO_INFO_NOW
houses_dict_list.append(house_details_dict)
return houses_dict_list

[3] Get latitude/longitude via the AMap (Gaode Maps) API

""" <A toolto transfer position.>
Copyright (C) <2017> Li W.H., Duan X
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
def Geocode(address):
""" A tool that call the God-Map api.
Arguments:
address: the address to transfer.
Returns:
location: the transfered location.
"""
CITY_NAME = '上海'
parameters = {'address': address,
'key': 'your key',
'city': CITY_NAME}
base = 'http://restapi.amap.com/v3/geocode/geo'
try:
response = requests.get(base, parameters)
except Exception as e:
print('error!', e)
finally:
pass
answer = response.json()
return answer
def GETGodMapLocation(houses):
""" Get the location that corresponds to the house name.
Use the God-Map api to get the corresponding location.
Arguments:
houses_dict_list: the houses info.
Returns:
houses_dict_list_contains_loc: the houses info that
contains the location info.
"""
HOUSE_NAME = '小区名称'
HOUSE_LOCATION = '经纬度'
NO_INFO_NOW = '暂无信息'
houses_dict_list = houses.copy()
error_count = 0
count = 0
size = len(houses)
for house_dict in houses_dict_list:
# Count
count = count + 1
# Loading needs
if count % 1000 == 0:
print(count, '/', size)
address = house_dict[HOUSE_NAME]
answer = Geocode(address)
# print(answer)
# If find
if len(answer['geocodes']) != 0:
# print(address + "的经纬度:", answer['geocodes'][0]['location'])
house_dict[HOUSE_LOCATION] = answer['geocodes'][0]['location']
else:
# remaking the invalid address
# print('address remaking...')
if re.search(r'别墅', address):
re.sub(r'别墅', '', address)
else:
address = address + '小区'
# print('retransfering...')
# transfer again
answer = Geocode(address)
if len(answer['geocodes']) != 0:
# print(address + "的经纬度:", answer['geocodes'][0]['location'])
house_dict[HOUSE_LOCATION] = answer['geocodes'][0]['location']
else:
# print(address)
error_count += 1
house_dict[HOUSE_LOCATION] = NO_INFO_NOW
print('error counts: ', error_count)
return houses_dict_list

[4] Save to an Excel file

""" <A tool to save the excel file.>
Copyright (C) <2017> Li W.H., Duan X
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
# ExcelWrite is presumably the xlwt package (e.g. `import xlwt as ExcelWrite`).
def Save2ExcelFile(houses):
    """ Save the python based list file to excel file.
    Arguments:
        houses: the houses list.
    """
    houses_dict_list = houses.copy()
    house_list = []
    # format the source data to fit the xlwt package
    keys = houses[0].keys()
    for key in keys:
        house = []
        house.append(key)
        for house_dict in houses_dict_list:
            house.append(house_dict[key])
        house_list.append(house)
    # return house_list
    xls = ExcelWrite.Workbook()
    sheet = xls.add_sheet('小区信息')
    # each inner list is one column (a key followed by its values),
    # so write with row index j and column index i
    for i in range(len(house_list)):
        for j in range(len(house_list[0])):
            sheet.write(j, i, house_list[i][j])
    xls.save('houses.xls')

Results

spider_house_txt_result

spider_house_excel_result


python_god_web_api

Posted on 2017-02-17
  1. http://lbs.amap.com/api/webservice/guide/api/search
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
'''
Convert an address to latitude/longitude with the AMap (Gaode Maps) API.
'''
import requests

def geocode(address):
    parameters = {'address': address, 'key': 'e798a5bfb344a09977b79552ae415974'}
    base = 'http://restapi.amap.com/v3/geocode/geo'
    response = requests.get(base, parameters)
    answer = response.json()
    print(address + "的经纬度:", answer['geocodes'][0]['location'])

if __name__ == '__main__':
    # address = input("请输入地址:")
    address = '北京市海淀区'
    geocode(address)
import os
import numpy as np
import xlrd

# Note: this is a method lifted from a larger class; `self.fn_rawDat` is the
# path of a cached .npy file defined elsewhere in that class.
def readXlsx(self, filename='CenterBottom2013.xlsx', sheetname='Sheet1'):
    rawData = []
    if os.path.isfile(self.fn_rawDat):
        with open(self.fn_rawDat, 'rb') as f:
            self.rawDat = np.load(f)
    else:
        workBook = xlrd.open_workbook(filename)
        bookSheet = workBook.sheet_by_name(sheetname)
        # start from the second row, because the first row holds the labels
        for row in range(1, bookSheet.nrows):
            rowData = []
            for col in range(bookSheet.ncols):
                cel = bookSheet.cell(row, col)
                try:
                    val = cel.value
                except Exception:
                    pass
                if type(val) == float:
                    val = float(val)
                else:
                    val = str(val)
                rowData.append(val)
            rawData.append(rowData)
        self.rawDat = np.array(rawData)
        with open(self.fn_rawDat, 'wb') as f:
            np.save(f, self.rawDat)
    return self.rawDat
  1. Read the Excel file
  2. Convert each address to location info
  3. Write the results back (a rough sketch of steps 2 and 3 follows)
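None of this glue code appears in the post. As a rough sketch, assuming pandas is installed, the spreadsheet has an '地址' (address) column, and a geocode() helper like the one above that returns the response dict instead of printing it, the three steps might look like:

import pandas as pd

def add_locations(in_file='addresses.xlsx', out_file='addresses_with_loc.xlsx'):
    """Read an Excel sheet, geocode each address, and write the result back."""
    df = pd.read_excel(in_file)                   # step 1: read the Excel file
    locations = []
    for address in df['地址']:
        answer = geocode(address)                 # step 2: address -> location
        if answer.get('geocodes'):
            locations.append(answer['geocodes'][0]['location'])
        else:
            locations.append('')                  # no match found
    df['经纬度'] = locations
    df.to_excel(out_file, index=False)            # step 3: write the results back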

kNN and kd-tree

Posted on 2017-02-15

The k-nearest neighbors algorithm

How it works

There is a labeled training set[1]. Whenever a new, unlabeled sample[2] arrives, its features are compared against those of every sample in the training set; some distance measure is used to pick out the k training samples most similar to it, and the new sample is then assigned the label held by the majority[4] of those k samples.

Pseudocode

For each point in the test set, do the following:

  1. Compute the distance between the current point and every point in the training set

  2. Sort the distances in increasing order

  3. Take the k points with the smallest distances

  4. Count the frequency of each class among those k points

  5. Return the most frequent class as the prediction for the current point

Implementation

First, create a toy dataset:

from numpy import *
def createDataset():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

Return the predicted class:

import operator

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndices = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # the return of sorted() is a list and its item is a tuple
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]  # returns the predict class label

Going further

A drawback of k-NN is that, for large datasets, its time and space complexity become unacceptable.

The other crucial issue is the choice of the hyperparameter k. Too small a k tends to overfit, while too large a k tends to underfit. Consider the two extremes: with k = 1 the method becomes the nearest-neighbor algorithm; with k = N[3] it simply predicts the class that is most common in the training data, which is obviously far too naive.
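To make the choice of k concrete, here is a minimal hold-out validation sketch; it is not part of the original note, it reuses the classify0 function defined above, and choose_k and its arguments are hypothetical names:

def choose_k(train_x, train_y, val_x, val_y, candidate_ks=(1, 3, 5, 7, 9)):
    """Return the k with the lowest error rate on a held-out validation set."""
    # train_x must be a NumPy array, since classify0 uses .shape internally
    best_k, best_err = None, float('inf')
    for k in candidate_ks:
        errors = sum(classify0(x, train_x, train_y, k) != y
                     for x, y in zip(val_x, val_y))
        err_rate = errors / len(val_y)
        if err_rate < best_err:            # keep the k that generalizes best
            best_k, best_err = k, err_rate
    return best_k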

To address the time complexity of kNN, the key is fast k-nearest-neighbor search over the data; one solution is to introduce a kd-tree to speed it up.

kd-trees

Overview

Take two-dimensional space as an example. Suppose there are six 2-D points {(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)}; the figure below shows what a kd-tree achieves:

kd-overview

The kd-tree algorithm has two main parts:

  1. Building the kd-tree data structure
  2. Searching on the kd-tree

A kd-tree is a binary tree that stores points in k-dimensional space so they can be searched efficiently. Building a kd-tree amounts to repeatedly partitioning k-dimensional space with hyperplanes perpendicular to the coordinate axes, producing a series of k-dimensional hyper-rectangles. Each node of the kd-tree corresponds to one such hyper-rectangle, i.e., a region of space.

Data structure

The table below lists the main fields stored in each node:

Field | Type | Description
Node-data | data vector | a data point from the dataset, a k-dimensional vector
Split | integer | index of the axis perpendicular to the splitting hyperplane
Left | k-d tree | the k-d tree built from all points in the left subspace of this node's splitting hyperplane
Right | k-d tree | the k-d tree built from all points in the right subspace of this node's splitting hyperplane

Pseudocode for building the tree

The pseudocode for building a kd-tree:

Algorithm: build k-d tree (createKDTree)
Input: the point set Data-set
Output: Kd, of type k-d tree
1. If Data-set is empty, return an empty k-d tree.
2. Call the node-generation procedure: (1) Determine the split field: for all descriptors (feature vectors), compute the variance of the data along each dimension. With SURF features, for example, the descriptor is 64-dimensional, so 64 variances are computed. Pick the dimension with the largest variance as the value of split; a large variance means the data are spread out along that axis, so splitting there gives better resolution. (2) Determine the Node-data field: sort Data-set by the value of its split-th dimension; the point in the middle becomes Node-data. The new set is Data-set' = Data-set \ Node-data (i.e., with Node-data removed).
3. dataleft = {d in Data-set' and d[split] <= Node-data[split]}, dataright = {d in Data-set' and d[split] > Node-data[split]}
4. left = the k-d tree built from dataleft, i.e., recursively call createKDTree(dataleft) and set left's parent to Kd; right = the k-d tree built from dataright, i.e., call createKDTree(dataright) and set right's parent to Kd.
Example

Using the six 2-D points from the beginning, let's make this process concrete:

  1. Determine the first value of the split field. Computing the variance of the data along x and y shows that the variance along x is largest, so split first takes the value 0, i.e., the x-axis;

  2. Determine Node-data. Sorting the x values 2, 5, 9, 4, 8, 7 gives the median 7, so Node-data = (7, 2). The node's splitting hyperplane is the line x = 7 through (7, 2) and perpendicular to split = 0 (the x-axis);

  3. Determine the left and right subspaces. The hyperplane x = 7 splits the space into two parts, as shown below. The part with x <= 7 is the left subspace, containing 3 points {(2, 3), (5, 4), (4, 7)}; the other part is the right subspace, containing 2 points {(9, 6), (8, 1)}.

    kd-construct-step1

As described, building a k-d tree is a recursive process. Repeating the root-node procedure on the data in the left and right subspaces yields the next level of nodes, (5, 4) and (9, 6) (the 'roots' of the two subspaces), further subdividing the space and the data. This continues until each region contains only one data point, as in Figure 1. The resulting k-d tree is shown below.

kd-construct-step2

Note: the 'x' and 'y' next to each node indicate the value of split used to divide that node's left and right subspaces.

A small supplementary remark: a kd-tree is simply a binary tree. It differs from an ordinary binary search tree in that each level splits on a specific dimension. Concretely, in the figure above, the first level splits on x, so the left child's x value is smaller than the root's and the right child's is larger; nothing is imposed on y (the 'left smaller, right larger' pattern in y here is pure coincidence). The second level splits on y, and the pattern for the third level follows in the same way.

Implementation

Environment: Windows 10 Pro 64-bit x64-based (Ver. 10.0.14393), Python 3.5.2, Anaconda 4.1.1 (64-bit), IPython 5.0.0, Windows CMD

kdTreeCreate.py

import numpy as np
from kdTreeNode import *
def createDataSet():
""" Create the test dataset.
Returns:
A numpy array that contains the test data.
"""
dataSet = np.array([[2, 3], [5, 4], [9, 6],
[4, 7], [8, 1], [7, 2]])
return dataSet
def split(dataSet):
""" Split the given dataset.
Returns:
LeftDataSet: A kdTreeNode object.
RightDataSet: A kdTreeNode object.
NodeData: A tuple.
"""
# Ensure the dimension to split
dimenIndex = np.var(dataSet, axis=0).argmax()
partitionDataSet = dataSet[:, dimenIndex]
# print(partitionDataSet)
# Ensure the position to split
partitionDataSetArgSort = partitionDataSet.argsort()
# print(partitionDataSetArgSort)
lenOfPartitionDataSetArgSort = len(partitionDataSetArgSort)
if lenOfPartitionDataSetArgSort % 2 == 0:
posIndex = lenOfPartitionDataSetArgSort // 2
splitIndex = partitionDataSetArgSort[posIndex]
else:
posIndex = (lenOfPartitionDataSetArgSort - 1) // 2
splitIndex = partitionDataSetArgSort[posIndex]
# print(splitIndex)
# Split
nodeData = dataSet[splitIndex]
leftIndeies = partitionDataSetArgSort[:posIndex]
rightIndeies = partitionDataSetArgSort[posIndex + 1:]
leftDataSet = dataSet[leftIndeies]
rightDataSet = dataSet[rightIndeies]
return nodeData, dimenIndex, leftDataSet, rightDataSet
def createKDTree(dataSet):
""" Create the KD tree.
Returns:
A kdTreeNode object.
"""
if len(dataSet) == 0:
return
nodeData, dimenIndex, leftDataSet, rightDataSet = split(dataSet)
# print(nodeData, dimenIndex, leftDataSet, rightDataSet)
node = kdTreeNode(nodeData, dimenIndex)
node.setLeft(createKDTree(leftDataSet))
node.setRight(createKDTree(rightDataSet))
return node
def midTravel(node):
if node is None:
return
midTravel(node.getLeft())
print(node.getData())
midTravel(node.getRight())
if __name__ == "__main__":
dataSet = createDataSet()
node = createKDTree(dataSet)
midTravel(node)

kdTreeNode.py

class kdTreeNode(object):
    """ Class of k-d tree nodes
    """
    def __init__(self, data=None, split=None, left=None, right=None):
        self.__data = data
        self.__split = split
        self.__left = left
        self.__right = right

    def getData(self):
        return self.__data

    def setData(self, data):
        self.__data = data

    def getSplit(self):
        return self.__split

    def setSplit(self, split):
        self.__split = split

    def getLeft(self):
        return self.__left

    def setLeft(self, left):
        self.__left = left

    def getRight(self):
        return self.__right

    def setRight(self, right):
        self.__right = right

Output:

In [1]: run kdTreeCreate.py
-------------------------------------
[2 3]
[5 4]
[4 7]
[7 2]
[8 1]
[9 6]

Time complexity:

For N points in K dimensions, the search operation takes $t=O(KN^{2})$.

Next, we perform search on the kd-tree we have just built.

Search

Search in a kd-tree differs considerably from ordinary tree search: its goal is to find the stored point closest to a query point.

The asterisk marks the query point (2.1, 3.1). A binary descent along the search path quickly finds an approximate nearest neighbor, the leaf node (2, 3). But the leaf found this way is not necessarily the true nearest neighbor; the true nearest neighbor must lie inside the circle centered at the query point and passing through that leaf. To find it, a 'backtracking' step is needed: walk back along the search path and check whether any node is closer to the query. In this example the binary descent starts at (7, 2), moves to (5, 4), and ends at (2, 3), so the search path is <(7, 2), (5, 4), (2, 3)>. First take (2, 3) as the current nearest neighbor; its distance to the query (2.1, 3.1) is 0.1414. Then backtrack to its parent (5, 4) and check whether the other subtree of that parent could contain a closer point. Drawing a circle of radius 0.1414 around (2.1, 3.1), as in Figure 4, we see that it does not intersect the hyperplane y = 4, so there is no need to search the right subspace of (5, 4).

Backtracking further to (7, 2): the circle of radius 0.1414 around (2.1, 3.1) certainly does not intersect the hyperplane x = 7 either, so there is no need to enter the right subspace of (7, 2). All nodes on the search path have now been backtracked, the search ends, and the nearest neighbor (2, 3) is returned with distance 0.1414.

kd-tree-search-1

A more complex example is the query point (2, 4.5). Again we do a binary descent, going from (7, 2) to (5, 4); since the split there is the hyperplane y = 4 and the query's y value is 4.5, we enter the right subspace and reach (4, 7), giving the search path <(7, 2), (5, 4), (4, 7)>. Take (4, 7) as the current nearest neighbor, at distance 3.202 from the query. Backtrack to (5, 4), whose distance to the query is 3.041. Draw a circle of radius 3.041 centered at (2, 4.5).

kd-tree-search-2

The circle intersects the hyperplane y = 4, so we must search the left subspace of (5, 4), pushing node (2, 3) onto the search path, which becomes <(7, 2), (2, 3)>. Backtracking to the leaf (2, 3): it is closer to (2, 4.5) than (5, 4) is, so the nearest neighbor is updated to (2, 3) and the nearest distance to 1.5. Backtracking to (7, 2): the circle of radius 1.5 around (2, 4.5) does not intersect the splitting hyperplane x = 7. The search path is exhausted; return the nearest neighbor (2, 3) with distance 1.5.

kd-tree-search-2

The pseudocode of the k-d tree query algorithm is shown below.

Search pseudocode
Algorithm: k-d tree nearest-neighbor search
Input: Kd, // a k-d tree
  target // the query data point
Output: nearest, // the nearest data point
  dist // the distance between the nearest point and the query
1. If Kd is NULL, set dist to infinity and return
2. // binary descent, building the search path
Kd_point = &Kd; // Kd_point holds the address of the k-d tree root
nearest = Kd_point -> Node-data; // initialize the nearest neighbor
while(Kd_point)
  push(Kd_point) onto search_path; // search_path is a stack holding pointers to the nodes on the search path
/*** If Dist(nearest,target) > Dist(Kd_point -> Node-data,target)
    nearest = Kd_point -> Node-data; // update the nearest neighbor
    Max_dist = Dist(Kd_point,target); // update the distance to the query ***/
  s = Kd_point -> split; // the splitting dimension
  If target[s] <= Kd_point -> Node-data[s] // binary descent
    Kd_point = Kd_point -> left;
  else
    Kd_point = Kd_point -> right;
nearest = the last leaf node on search_path; // note: during the descent we do not compare against the nodes on the path; that part is commented out
Max_dist = Dist(nearest,target); // take the last leaf directly as the initial nearest neighbor before backtracking
3. // backtracking
while(search_path != NULL)
  back_point = pop a node pointer from search_path; // pop from the search_path stack
  s = back_point -> split; // the splitting dimension
  If Dist(target[s],back_point -> Node-data[s]) < Max_dist // decide whether the other subspace must be entered
    If target[s] <= back_point -> Node-data[s]
      Kd_point = back_point -> right; // if target lies in the left subspace, enter the right one
    else
      Kd_point = back_point -> left; // if target lies in the right subspace, enter the left one
    push Kd_point onto search_path;
  If Dist(nearest,target) > Dist(Kd_point -> Node-data,target)
    nearest = Kd_point -> Node-data; // update the nearest neighbor
    Min_dist = Dist(Kd_point -> Node-data,target); // update the distance between the nearest neighbor and the query
Implementation

kdTreeSearch.py

import numpy as np
def cal_dist(node, target):
""" Calculate the distance between the node
and the target.
Arguments:
node: The kd-tree's one node.
target: Search target.
Returns:
dist: The distance between the two nodes.
"""
node_data = np.array(node)
target_data = np.array(target)
square_dist_vector = (node_data - target_data) ** 2
square_dist = np.sum(square_dist_vector)
dist = square_dist ** 0.5
return dist
def search(root_node, target):
""" Search the nearest node of the target node
in the kd-tree that root node is the root_node
Arguments:
root_node: The kd-tree's root node.
target: Search target.
Returns:
nearest: The nearest node of the target node in the kd-tree.
min_dist: The nearest distance.
"""
if root_node is None:
min_dist = float('inf')
return min_dist
# Two-fork search
kd_point = root_node # Save the root node
nearest = kd_point.getData() # Initial the nearest node
search_path = [] # Initial the search stack
while kd_point:
search_path.append(kd_point)
split_index = kd_point.getSplit() # Ensure the split path
if target[split_index] <= kd_point.getData()[split_index]:
kd_point = kd_point.getLeft()
else:
kd_point = kd_point.getRight()
nearest = search_path.pop().getData()
min_dist = cal_dist(nearest, target)
# Retrospect search
while search_path:
back_point = search_path.pop()
# Ensure the back-split path
back_split_index = back_point.getSplit()
# Judge if needs to enter the subspace
if cal_dist(target[back_split_index],
back_point.getData()[back_split_index]) < min_dist:
# If the target is in the left subspace, then enter the right
if target[back_split_index] <= back_point.getData()[back_split_index]:
kd_point = back_point.getRight()
# Otherwise enter the left
else:
kd_point = back_point.getLeft()
# Add the node to the search path
if kd_point is not None:
search_path.append(kd_point)
if cal_dist(nearest, target) > cal_dist(kd_point.getData(), target):
# Update the nearest node
nearest = kd_point.getData()
# Update the maximum distance
min_dist = cal_dist(kd_point.getData(), target)
return nearest, min_dist

Output:

In [1]: run kdTreeCreate.py
-------------------------------------
[2 3]
[5 4]
[4 7]
[7 2]
[8 1]
[9 6]
In [2]: node
-------------------------------------
Out [2]: <kdTreeNode.kdTreeNode at 0x26bff22f160>
In [3]: import kdTreeSearch
In [4]: nearest, min_dist = kdTreeSearch.search(node, [2.1, 3.1])
In [5]: nearest
-------------------------------------
Out [5]: array([2, 3])
In [6]: min_dist
-------------------------------------
Out [6]: 0.14142135623730964
In [7]: nearest, min_dist = kdTreeSearch.search(node, [2, 4.5])
In [8]: nearest
-------------------------------------
Out [8]: array([2, 3])
In [9]: min_dist
-------------------------------------
Out [9]: 1.5

Time complexity:

For a kd-tree with N nodes in K dimensions, the worst-case search time is $t_{worst}=O(KN^{1-1/K})$.

According to the literature, for data of dimension K an efficient search is only achieved when the number of points N satisfies $N \gg 2^K$ (for K < 20; beyond roughly 20 dimensions, ball-tree methods can be used instead). This has motivated a series of improved algorithms (BBF, and high-dimensional index trees such as M-trees, VP-trees, and MVP-trees), to be covered later.

k-NN with a kd-tree

The next step is to combine the two; a rough sketch follows.
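As a minimal sketch of that combination (not the post's own code, and assuming SciPy is available), one can let scipy.spatial.cKDTree do the k-nearest-neighbor query and then take a majority vote over the returned neighbors:

from collections import Counter
import numpy as np
from scipy.spatial import cKDTree

def knn_kdtree_classify(train_x, train_y, query, k=3):
    """Classify `query` by majority vote among its k nearest training points."""
    tree = cKDTree(np.asarray(train_x))           # build the kd-tree once
    _, idx = tree.query(np.asarray(query), k=k)   # indices of the k nearest points
    idx = np.atleast_1d(idx)                      # k=1 returns a scalar index
    votes = Counter(train_y[i] for i in idx)
    return votes.most_common(1)[0][0]

# With the toy data from createDataset():
# group, labels = createDataset()
# knn_kdtree_classify(group, labels, [0.1, 0.2], k=3)   # -> 'B'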

[1] Strictly speaking, "training set" is not accurate here: k-NN is a non-parametric method with a single hyperparameter k, so there is no actual training step.

[2] I.e., the test sample.

[3] N is the size of the training set.

[4] I.e., majority voting.


Decision tree

Posted on 2017-02-15

Decision trees (ID3)

Building a decision tree

The first problem in building a decision tree is deciding which feature to split on at each branch. We need to define some criterion, evaluate it for every feature, and split on the feature with the best score.

After a split, the original dataset is partitioned into several subsets. If all the data under a branch belong to the same class, the algorithm stops there; otherwise, the splitting process is repeated.

Pseudocode (creating branches)

createBranch()
    Check whether every item in the dataset belongs to the same class:
    If so, return the class label;
    Else
        find the best feature for splitting the dataset
        split the dataset
        create a branch node
        for each subset of the split
            call createBranch and add the result to the branch node
        return the branch node

The key question, then, is how to find the best feature to split on. Here we use the criterion from the ID3 algorithm, i.e., splitting based on entropy.

Information gain

Entropy

The core idea of splitting is to make disordered data more ordered. How disordered a dataset is can be quantified as information, and the measure used is entropy. Clearly the amount of information changes before and after a split; that change is called the information gain.

Entropy is defined as the expected value of information, where the information of a single class is defined as follows:

$$l(x_i) = -log_{2}p(x_i)$$

where $x_i$ denotes a class and $p(x_i)$ is the probability of picking that class.

The entropy is then defined as:

$$H = \sum_{i=1}^{n}p(x_i)l(x_i)=-\sum_{i=1}^{n}p(x_i)log_{2}p(x_i)$$
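As a quick worked example with the toy dataset used below (two 'yes' samples and three 'no' samples):

$$H = -\frac{2}{5}\log_{2}\frac{2}{5} - \frac{3}{5}\log_{2}\frac{3}{5} \approx 0.971$$

which matches the value returned by the demo further down.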

Information gain is defined as:

$$IG(S|T) = H(S) - \sum_{value(T)} \frac{|S_v|}{|S|} H(S_v)$$

where $S$ is the full sample set, $value(T)$ is the set of all values of attribute $T$, $v$ is one value of $T$, $S_v$ is the subset of $S$ whose attribute $T$ equals $v$, $|S_v|$ is the number of samples in $S_v$, and $|S|$ is the number of samples in $S$.

Implementation:

from math import log

def CalcShannonEnt(data_set):
    """ Calculate the Shannon Entropy.
    Arguments:
        data_set: The object dataset.
    Returns:
        shannon_ent: The Shannon entropy of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Statistics the frequency of each class in the dataset
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # Calculates the Shannon entropy
    shannon_ent = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries
        shannon_ent -= prob * log(prob, 2)
    return shannon_ent

For testing, and for running the algorithm later, here is a very naive data-generation method:

def CreateDataSet():
    """ A naive data generation method.
    Returns:
        data_set: The data set excepts label info.
        labels: The data set only contains label info.
    """
    data_set = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return data_set, labels

Note that `labels` here are not the class labels; 'yes' and 'no' are. `labels` holds the feature names.

A quick demo:

In [22]: import trees
In [23]: my_dat, labels = trees.CreateDataSet()
In [24]: my_dat
Out[24]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
In [25]: trees.CalcShannonEnt(my_dat)
Out[25]: 0.9709505944546686

The higher the entropy, the more mixed the dataset is (more classes, or classes more evenly distributed).

Another measure of disorder is the Gini impurity.

Gini impurity

The Gini impurity of a node is the probability, summed over classes, that a label drawn at random from the node's class distribution disagrees with the true label of a randomly chosen sample from the node. Formally:

$$G = \sum_{i \ne j}p(x_i)p(x_j) = \sum_{i}p(x_i)\sum_{j \ne i}p(x_j) = \sum_{i}p(x_i)(1-p(x_i)) = \sum_{i}p(x_i) - \sum_{i}(p(x_i))^2 = 1 - \sum_{i}(p(x_i))^2$$
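For the same toy dataset (class probabilities 2/5 and 3/5), this gives $G = 1 - \left( \left( \frac{2}{5} \right)^2 + \left( \frac{3}{5} \right)^2 \right) = 0.48$, matching the demo below.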

Implementation:

def CalcGiniImpurity(data_set):
    """ Calculate the Gini impurity.
    Arguments:
        data_set: The object dataset.
    Returns:
        gini_impurity: The Gini impurity of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Statistics the frequency of each class in the dataset
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # Calculates the Gini impurity
    gini_impurity = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries
        gini_impurity += pow(prob, 2)
    gini_impurity = 1 - gini_impurity
    return gini_impurity

Again, a quick demo:

In [4]: import trees
In [5]: my_dat, labels = trees.CreateDataSet()
In [6]: my_dat
Out[6]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
In [7]: trees.CalcGiniImpurity(my_dat)
Out[7]: 0.48

Finally, one more measure of disorder: the misclassification impurity.

Misclassification impurity

It is defined as:

$$M = 1 - \max_{i}(p(x_i))$$
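For the toy dataset this gives $M = 1 - \frac{3}{5} = 0.4$, matching the demo below.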

Implementation:

def CalcMisClassifyImpurity(data_set):
    """ Calculate the misclassification impurity.
    Arguments:
        data_set: The object dataset.
    Returns:
        mis_classify_impurity: The misclassification impurity of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Statistics the frequency of each class in the dataset
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # Calculates the misclassification impurity
    mis_classify_impurity = 0.0
    max_prob = max(label_counts.values()) / num_entries
    mis_classify_impurity = 1 - max_prob
    return mis_classify_impurity

A quick demo:

In [25]: reload(trees)
Out[25]: <module 'trees' from 'C:\\Users\\Ewan\\Documents\\GitHub\\hexo\\public\\2017\\02\\15\\Decision-tree\\trees.py'>
In [26]: my_dat, labels = trees.CreateDataSet()
In [27]: my_dat
Out[27]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
In [28]: trees.CalcMisClassifyImpurity(my_dat)
Out[28]: 0.4

Finally, a figure comparing the three impurity measures as functions of the class probability (for the two-class case)

[Ref: http://www.cse.msu.edu/~cse802/DecisionTrees.pdf]:

impurity_compare

Splitting the data

Building on the above, the splitting strategy is: for each feature, split the data on each of its values, compute the information gain of the split, and finally split on the feature whose values give the largest gain. Since we are using ID3, the information measure is entropy.

Implementation:

def SplitDataSet(data_set, axis, value):
    """ Split the data set according to the given axis and correspond value.
    Arguments:
        data_set: Object data set.
        axis: The split-feature index.
        value: The value of the split-feature.
    Returns:
        ret_data_set: The splited data set.
    """
    ret_data_set = []
    for feat_vec in data_set:
        if feat_vec[axis] == value:
            reduced_feat_vec = feat_vec[:axis]
            reduced_feat_vec.extend(feat_vec[axis + 1:])
            ret_data_set.append(reduced_feat_vec)
    return ret_data_set

def ChooseBestFeatureToSplit(data_set):
    """ Choose the best feature to split.
    Arguments:
        data_set: Object data set.
    Returns:
        best_feature: The index of the feature to split.
    """
    # Initiation
    # The last column holds the class label, so it is not counted as a feature
    num_features = len(data_set[0]) - 1
    base_entropy = CalcShannonEnt(data_set)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        # Choose the i-th feature of all data
        feat_list = [example[i] for example in data_set]
        # Abandon the repeat feature value(s)
        unique_vals = set(feat_list)
        new_entropy = 0.0
        # Calculates the Shannon entropy of the splited data set
        for value in unique_vals:
            sub_data_set = SplitDataSet(data_set, i, value)
            prob = len(sub_data_set) / len(data_set)
            new_entropy += prob * CalcShannonEnt(sub_data_set)
        # base_entropy is equal or greater than new_entropy
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature

As the code above (ID3) shows, the entropy is computed from the last column of each record (the class). Is this naive choice guaranteed to give the best result on average? Also, the split criterion only looks at the entropy of the subsets one split ahead, which is a greedy strategy; can it reach the optimal solution?

Building the tree recursively

Since the algorithm is recursive, we must define when the recursion stops:

  1. All attributes have been used up
  2. All data under a branch belong to the same class

One issue remains: if all attributes are used up and some branch still contains multiple classes, we usually fall back on majority voting, implemented as follows:

def majority_cnt(class_list):
    """ Decide the final class.
    When the split data do not belong to the same class after all
    features have been used, the final class is decided by majority vote.
    Arguments:
        class_list: The class list of the splited data set.
    Returns:
        sorted_class_count[0][0]: The majority class.
    """
    class_count = {}
    for vote in class_list:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote] += 1
    # use items() (Python 3); iteritems() only exists in Python 2
    sorted_class_count = sorted(
        class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]

Now build the tree:

def CreateTree(data_set, labels):
    """ Create decision tree.
    Arguments:
        data_set: The object data set.
        labels: The feature labels in the data_set.
    Returns:
        my_tree: A dict that represents the decision tree.
    """
    class_list = [example[-1] for example in data_set]
    # If the classes are fully same
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # If all feature is handled
    if len(data_set[0]) == 1:
        return majority_cnt(class_list)
    # Get the best split-feature and the correspond label
    best_feat = ChooseBestFeatureToSplit(data_set)
    best_feat_label = labels[best_feat]
    # Build a recurrence dict
    my_tree = {best_feat_label: {}}
    # Get the next step labels parameter
    del(labels[best_feat])
    # Next step start
    feat_values = [example[best_feat] for example in data_set]
    unique_vals = set(feat_values)
    for value in unique_vals:
        sub_labels = labels[:]
        # Recurrence calls
        my_tree[best_feat_label][value] = CreateTree(
            SplitDataSet(data_set, best_feat, value), sub_labels)
    return my_tree

A quick test:

In [27]: reload(trees)
Out[27]: <module 'trees' from 'C:\\Users\\Ewan\\Documents\\GitHub\\hexo\\public\\2017\\02\\15\\Decision-tree\\trees.py'>
In [28]: my_dat, labels = trees.CreateDataSet()
In [29]: my_tree = trees.CreateTree(my_dat, labels)
In [30]: my_tree
Out[30]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

The decision tree has been built successfully (visualized below), but that is clearly not enough: we want to use it to classify.

splitting_path

Classifying with the decision tree

A demo (the Classify helper itself is sketched right after the demo):

In [63]: reload(trees)
Out[63]: <module 'trees' from 'C:\\Users\\Ewan\\Documents\\GitHub\\hexo\\public\\2017\\02\\15\\Decision-tree\\trees.py'>
In [64]: my_dat, labels = trees.CreateDataSet()
In [65]: labels
Out[65]: ['no surfacing', 'flippers']
In [66]: my_tree = trees.CreateTree(my_dat, labels)
In [67]: labels
Out[67]: ['no surfacing', 'flippers']
In [68]: my_tree
Out[68]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
In [69]: trees.Classify(my_tree, labels, [1, 0])
Out[69]: 'no'
In [70]: trees.Classify(my_tree, labels, [1, 1])
Out[70]: 'yes'
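The Classify function used above is not listed anywhere in this post; a minimal sketch consistent with the nested-dict tree built by CreateTree might look like this:

def Classify(input_tree, feat_labels, test_vec):
    """Walk the nested-dict tree and return the predicted class label."""
    first_key = list(input_tree.keys())[0]        # feature name at this node
    second_dict = input_tree[first_key]
    feat_index = feat_labels.index(first_key)     # position of that feature in test_vec
    sub_tree = second_dict[test_vec[feat_index]]
    if isinstance(sub_tree, dict):                # internal node: keep descending
        return Classify(sub_tree, feat_labels, test_vec)
    return sub_tree                               # leaf: the class label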

C4.5

C4.5 is derived from ID3, with two main improvements:

  1. The best split attribute is chosen by the information gain ratio (IGR)
  2. The algorithm also handles continuous attributes

The split information and the information gain ratio are defined as follows:

$$IGR = \frac{IG}{IV}$$
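The split information IV in the denominator is not spelled out in the original note; the standard C4.5 definition is

$$IV(S, T) = -\sum_{v \in value(T)} \frac{|S_v|}{|S|} \log_{2} \frac{|S_v|}{|S|}$$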

So only a few changes to the ID3 code are needed; to stay compatible with ID3, the implementation is as follows:

def ChooseBestFeatureToSplit(data_set, flag='ID3'):
""" Choose the best feature to split.
Arguments:
data_set: Object data set.
flag: Decide if use the infomation gain rate or not.
Returns:
best_feature: The index of the feature to split.
"""
# Initiation
# Because the range() method excepts the lastest number
num_features = len(data_set[0]) - 1
base_entropy = CalcShannon(data_set)
method = 'ID3'
best_feature = -1
best_info_gain = 0.0
best_info_gain_rate = 0.0
for i in range(num_features):
new_entropy = 0.0
# Choose the i-th feature of all data
feat_list = [example[i] for example in data_set]
# Abandon the repeat feature value(s)
unique_vals = set(feat_list)
if len(unique_vals) > 3:
method = 'C4.5'
if method == 'ID3':
# Calculates the Shannon entropy of the splited data set
for value in unique_vals:
sub_data_set = SplitDataSet(data_set, i, value)
prob = len(sub_data_set) / len(data_set)
new_entropy += prob * CalcShannon(sub_data_set)
else:
data_set = np.array(data_set)
sorted_feat = np.argsort(feat_list)
for index in range(len(sorted_feat) - 1):
pre_sorted_feat, post_sorted_feat = np.split(
sorted_feat, [index + 1, ])
pre_data_set = data_set[pre_sorted_feat]
post_data_set = data_set[post_sorted_feat]
pre_coff = len(pre_sorted_feat) / len(sorted_feat)
post_coff = len(post_sorted_feat) / len(sorted_feat)
# Calucate the split info
iv = pre_coff * CalcShannon(pre_data_set) + \
post_coff * CalcShannon(post_data_set)
if iv > new_entropy:
new_entropy = iv
# base_entropy is equal or greatter than new_entropy
info_gain = base_entropy - new_entropy
if flag == 'C4.5':
info_gain_rate = info_gain / new_entropy
# print('index', i, 'info_gain_rate', info_gain_rate)
if info_gain_rate > best_info_gain_rate:
best_info_gain_rate = info_gain_rate
best_feature = i
if flag == 'ID3':
if info_gain > best_info_gain:
best_info_gain = info_gain
best_feature = i
return best_feature

The next problem is handling continuous attributes. For convenience we change the naive data-generation method (Ref: http://blog.csdn.net/lemon_tree12138/article/details/51840361 ):

def CreateDataSet(method='ID3'):
""" A naive data generation method.
Arguments:
method: The algorithm class
Returns:
data_set: The data set excepts label info.
labels: The data set only contains label info.
"""
# Arguments check
if method not in ('ID3', 'C4.5'):
raise ValueError('invalid value: %s' % method)
if method == 'ID3':
data_set = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
else:
data_set = [[85, 85, 'no'],
[80, 90, 'yes'],
[83, 78, 'no'],
[70, 96, 'no'],
[68, 80, 'no'],
[65, 70, 'yes'],
[64, 65, 'yes'],
[72, 95, 'no'],
[69, 70, 'no'],
[75, 80, 'no'],
[75, 70, 'yes'],
[72, 90, 'yes'],
[81, 75, 'no'],
[71, 80, 'yes']]
labels = ['temperature', 'humidity']
return data_set, labels

Suppose we pick the temperature attribute; the extracted key data are:

[[85, No], [80, No], [83, Yes], [70, Yes], [68, Yes], [65, No], [64, Yes], [72, No], [69, Yes], [75, Yes], [75, Yes], [72, Yes], [81, Yes], [71, No]]

Now sort these from smallest to largest; the dataset becomes:

[[64, Yes], [65, No], [68, Yes], [69, Yes], [70, Yes], [71, No], [72, No], [72, Yes], [75, Yes], [75, Yes], [80, No], [81, Yes], [83, Yes], [85, No]]

Illustrated as follows:

c4.5_data_sorted

Given this sorted (temperature, outcome) list, we compute, for each position, the split information of the partition into the part to its left and the part to its right. For example, at index = 4:

$$IV(v_4) = IV([4, 1], [5, 4]) = \frac{5}{14}IV([4, 1]) + \frac{9}{14}IV([5, 4])$$

$$IV(v_4) = \frac{5}{14}(-\frac{4}{5} \log_{2} \frac{4}{5} - \frac{1}{5} \log_{2} \frac{1}{5}) + \frac{9}{14}(-\frac{5}{9} \log_{2} \frac{5}{9} - \frac{4}{9} \log_{2} \frac{4}{9})$$
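Numerically, $IV([4, 1]) \approx 0.722$ and $IV([5, 4]) \approx 0.991$, so $IV(v_4) \approx \frac{5}{14} \times 0.722 + \frac{9}{14} \times 0.991 \approx 0.895$.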

The figure below shows the split information obtained at different split positions:

c4_5_data_split

Finally, the complete implementation (the Classify method at the end still needs modification):

trees.py

from math import log
import operator

import numpy as np

def CalcShannon(data_set):
    """ Calculate the Shannon entropy.
    Arguments:
        data_set: The object data set.
    Returns:
        shannon_ent: The Shannon entropy of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Count the frequency of each class in the data set
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # print(label_counts)
    # Calculate the Shannon entropy
    shannon_ent = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries
        shannon_ent -= prob * log(prob, 2)
    return shannon_ent
def CalcGiniImpurity(data_set):
    """ Calculate the Gini impurity.
    Arguments:
        data_set: The object data set.
    Returns:
        gini_impurity: The Gini impurity of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Count the frequency of each class in the data set
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # Calculate the Gini impurity: 1 - sum of squared class probabilities
    gini_impurity = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries
        gini_impurity += pow(prob, 2)
    gini_impurity = 1 - gini_impurity
    return gini_impurity
def CalcMisClassifyImpurity(data_set):
    """ Calculate the misclassification impurity.
    Arguments:
        data_set: The object data set.
    Returns:
        mis_classify_impurity:
            The misclassification impurity of the object data set.
    """
    # Initiation
    num_entries = len(data_set)
    label_counts = {}
    # Count the frequency of each class in the data set
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    # Calculate the misclassification impurity: 1 - max class probability
    mis_classify_impurity = 0.0
    max_prob = max(label_counts.values()) / num_entries
    mis_classify_impurity = 1 - max_prob
    return mis_classify_impurity
def CreateDataSet(method='ID3'):
    """ A naive data generation method.
    Arguments:
        method: The algorithm class
    Returns:
        data_set: The data set, with the class label as the last column.
        labels: The names of the features.
    """
    # Arguments check
    if method not in ('ID3', 'C4.5'):
        raise ValueError('invalid value: %s' % method)
    if method == 'ID3':
        data_set = [[1, 1, 1],
                    [1, 1, 1],
                    [1, 0, 0],
                    [0, 1, 0],
                    [0, 1, 0]]
        labels = ['no surfacing', 'flippers']
    else:
        data_set = [[1, 85, 85, 0, 0],
                    [1, 80, 90, 1, 0],
                    [2, 83, 78, 0, 1],
                    [3, 70, 96, 0, 1],
                    [3, 68, 80, 0, 1],
                    [3, 65, 70, 1, 0],
                    [2, 64, 65, 1, 1],
                    [1, 72, 95, 0, 0],
                    [1, 69, 70, 0, 1],
                    [3, 75, 80, 0, 1],
                    [1, 75, 70, 1, 1],
                    [2, 72, 90, 1, 1],
                    [2, 81, 75, 0, 1],
                    [3, 71, 80, 1, 0]]
        labels = ['outlook', 'temperature', 'humidity', 'windy']
    return data_set, labels
def SplitDataSet(data_set, axis, value):
    """ Split the data set according to the given axis and corresponding value.
    Arguments:
        data_set: Object data set.
        axis: The split-feature index.
        value: The value of the split-feature.
    Returns:
        ret_data_set: The split data set.
    """
    ret_data_set = []
    for feat_vec in data_set:
        if feat_vec[axis] == value:
            # Drop the split feature itself
            reduced_feat_vec = feat_vec[:axis]
            reduced_feat_vec.extend(feat_vec[axis + 1:])
            ret_data_set.append(reduced_feat_vec)
    return ret_data_set
def ChooseBestFeatureToSplit(data_set, flag='ID3'):
    """ Choose the best feature to split on.
    Arguments:
        data_set: Object data set.
        flag: Decides whether to use the information gain ratio ('C4.5')
              or the plain information gain ('ID3').
    Returns:
        best_feature: The index of the feature to split on.
    """
    # Initiation
    # The last column is the class label, so it is not counted as a feature
    num_features = len(data_set[0]) - 1
    base_entropy = CalcShannon(data_set)
    method = 'ID3'
    best_feature = -1
    best_info_gain = 0.0
    best_info_gain_rate = 0.0
    for i in range(num_features):
        new_entropy = 0.0
        # Choose the i-th feature of all data
        feat_list = [example[i] for example in data_set]
        # Abandon the repeated feature value(s)
        unique_vals = set(feat_list)
        # More than 3 distinct values: treat the feature as continuous
        if len(unique_vals) > 3:
            method = 'C4.5'
        if method == 'ID3':
            # Calculate the Shannon entropy of the split data set
            for value in unique_vals:
                sub_data_set = SplitDataSet(data_set, i, value)
                prob = len(sub_data_set) / len(data_set)
                new_entropy += prob * CalcShannon(sub_data_set)
        else:
            data_set = np.array(data_set)
            sorted_feat = np.argsort(feat_list)
            for index in range(len(sorted_feat) - 1):
                pre_sorted_feat, post_sorted_feat = np.split(
                    sorted_feat, [index + 1, ])
                pre_data_set = data_set[pre_sorted_feat]
                post_data_set = data_set[post_sorted_feat]
                pre_coff = len(pre_sorted_feat) / len(sorted_feat)
                post_coff = len(post_sorted_feat) / len(sorted_feat)
                # Calculate the split info
                iv = pre_coff * CalcShannon(pre_data_set) + \
                    post_coff * CalcShannon(post_data_set)
                if iv > new_entropy:
                    new_entropy = iv
        # base_entropy is equal to or greater than new_entropy
        info_gain = base_entropy - new_entropy
        if flag == 'C4.5':
            info_gain_rate = info_gain / new_entropy
            # print('index', i, 'info_gain_rate', info_gain_rate)
            if info_gain_rate > best_info_gain_rate:
                best_info_gain_rate = info_gain_rate
                best_feature = i
        if flag == 'ID3':
            if info_gain > best_info_gain:
                best_info_gain = info_gain
                best_feature = i
    return best_feature
def majority_cnt(class_list):
    """ Decide the final class.
    When the split data does not all belong to the same class
    after every feature has been handled, the final class is
    decided by majority vote.
    Arguments:
        class_list: The class list of the split data set.
    Returns:
        sorted_class_count[0][0]: The majority class.
    """
    class_count = {}
    for vote in class_list:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote] += 1
    sorted_class_count = sorted(
        class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]
def CreateTree(data_set, feat_labels, method='ID3'):
    """ Create the decision tree.
    Arguments:
        data_set: The object data set.
        feat_labels: The feature labels in the data_set.
        method: The algorithm class.
    Returns:
        my_tree: A dict that represents the decision tree.
    """
    # Arguments check
    if method not in ('ID3', 'C4.5'):
        raise ValueError('invalid value: %s' % method)
    labels = feat_labels.copy()
    class_list = [example[-1] for example in data_set]
    # print(class_list)
    # If the classes are fully the same
    print('class_list', class_list)
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # If all features are handled
    if len(data_set[0]) == 1:
        return majority_cnt(class_list)
    if method == 'ID3':
        # Get the best split-feature and the corresponding label
        best_feat = ChooseBestFeatureToSplit(data_set)
        best_feat_label = labels[best_feat]
        # print(best_feat_label)
        # Build a recursive dict
        my_tree = {best_feat_label: {}}
        # Next step start
        feat_values = [example[best_feat] for example in data_set]
        # Get the next step labels parameter
        del(labels[best_feat])
        unique_vals = set(feat_values)
        for value in unique_vals:
            sub_labels = labels[:]
            # Recursive calls
            my_tree[best_feat_label][value] = CreateTree(
                SplitDataSet(data_set, best_feat, value), sub_labels)
        return my_tree
    else:
        flag = 'ID3'
        # Get the best split-feature and the corresponding label
        best_feat = ChooseBestFeatureToSplit(data_set, 'C4.5')
        best_feat_label = labels[best_feat]
        print(best_feat_label)
        # Build a recursive dict
        my_tree = {best_feat_label: {}}
        # Next step start
        feat_values = [example[best_feat] for example in data_set]
        del(labels[best_feat])
        unique_vals = set(feat_values)
        # More than 3 distinct values: treat the feature as continuous
        if len(unique_vals) > 3:
            flag = 'C4.5'
        if flag == 'ID3':
            for value in unique_vals:
                sub_labels = labels[:]
                # Recursive calls
                my_tree[best_feat_label][value] = CreateTree(
                    SplitDataSet(data_set, best_feat, value),
                    sub_labels, 'C4.5')
            return my_tree
        else:
            data_set = np.array(data_set)
            best_iv = 0.0
            best_split_value = -1
            sorted_feat = np.argsort(feat_values)
            for i in range(len(sorted_feat) - 1):
                pre_sorted_feat, post_sorted_feat = np.split(
                    sorted_feat, [i + 1, ])
                pre_data_set = data_set[pre_sorted_feat]
                post_data_set = data_set[post_sorted_feat]
                pre_coff = len(pre_sorted_feat) / len(sorted_feat)
                post_coff = len(post_sorted_feat) / len(sorted_feat)
                # Calculate the split info
                iv = pre_coff * CalcShannon(pre_data_set) + \
                    post_coff * CalcShannon(post_data_set)
                if iv > best_iv:
                    best_iv = iv
                    best_split_value = feat_values[sorted_feat[i]]
            print(best_feat, best_split_value)
            # print(best_split_value)
            left_data_set = data_set[
                data_set[:, best_feat] <= best_split_value]
            left_data_set = np.delete(left_data_set, best_feat, axis=1)
            # if len(left_data_set) == 1:
            #     return left_data_set[0][-1]
            right_data_set = data_set[
                data_set[:, best_feat] > best_split_value]
            right_data_set = np.delete(right_data_set, best_feat, axis=1)
            # if len(right_data_set) == 1:
            #     return right_data_set[0][-1]
            sub_labels = labels[:]
            my_tree[best_feat_label][
                '<=' + str(best_split_value)] = CreateTree(
                left_data_set.tolist(), sub_labels, 'C4.5')
            my_tree[best_feat_label][
                '>' + str(best_split_value)] = CreateTree(
                right_data_set.tolist(), sub_labels, 'C4.5')
            # print('continuous tree', my_tree)
            return my_tree
def Classify(input_tree, feat_labels, test_vec):
    """ Classify using the given decision tree.
    Arguments:
        input_tree: The given decision tree.
        feat_labels: The labels of the corresponding features.
        test_vec: The test data.
    Returns:
        class_label: The class label that corresponds to the test data.
    """
    # Get the start feature label to split on
    first_str = list(input_tree.keys())[0]
    # Get the sub-tree that corresponds to the start split feature
    second_dict = input_tree[first_str]
    # Get the feature index whose label is the start feature label
    feat_index = feat_labels.index(first_str)
    # Start recursive search
    for key in second_dict.keys():
        if test_vec[feat_index] == key:
            if type(second_dict[key]).__name__ == 'dict':
                # Recursive calls
                class_label = Classify(second_dict[key], feat_labels, test_vec)
            else:
                class_label = second_dict[key]
    return class_label
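
The Classify method above matches branch keys by equality only, so it cannot follow the '<=value' / '>value' branches created for continuous features; this is exactly the modification flagged as still needed. One possible adaptation is sketched below as a hypothetical variant (ClassifyContinuous is my own name, not part of the original trees.py):

def ClassifyContinuous(input_tree, feat_labels, test_vec):
    """ Hypothetical variant of Classify that also follows the
    '<=value' / '>value' branches produced by continuous splits.
    """
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    class_label = None
    for key in second_dict.keys():
        if isinstance(key, str) and key.startswith(('<=', '>')):
            # Continuous branch: the key encodes a threshold, e.g. '<=90'
            threshold = float(key.lstrip('<=>'))
            matched = (test_vec[feat_index] <= threshold
                       if key.startswith('<=')
                       else test_vec[feat_index] > threshold)
        else:
            # Discrete branch: plain equality, as in the original Classify
            matched = test_vec[feat_index] == key
        if matched:
            sub_tree = second_dict[key]
            class_label = (ClassifyContinuous(sub_tree, feat_labels, test_vec)
                           if isinstance(sub_tree, dict) else sub_tree)
    return class_label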

A small demo:

In [108]: reload(trees)
Out[108]: <module 'trees' from 'C:\\Users\\Ewan\\Documents\\GitHub\\hexo\\public\\2017\\02\\15\\Decision-tree\\trees.py'>
In [109]: my_dat, labels = trees.CreateDataSet('C4.5')
In [110]: my_tree_c = trees.CreateTree(my_dat, labels, 'C4.5')
class_list [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]
outlook
class_list [0, 0, 0, 1, 1]
humidity
1 90
class_list [0, 0, 1, 1]
temperature
0 69
class_list [1]
class_list [0, 0, 1]
windy
class_list [0]
class_list [0, 1]
class_list [0]
class_list [1, 1, 1, 1]
class_list [1, 1, 0, 1, 0]
windy
class_list [1, 1, 1]
class_list [0, 0]
In [111]: my_tree_c
Out[111]:
{'outlook': {1: {'humidity': {'<=90': {'temperature': {'<=69': 1,
'>69': {'windy': {0: 0, 1: 0}}}},
'>90': 0}},
2: 1,
3: {'windy': {0: 1, 1: 0}}}}
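
As a quick check of Classify on the discrete ID3 tree (the continuous tree above needs an adapted classifier, as sketched after trees.py), a hedged usage sketch; the expected outputs in the comments come from tracing the tree, not from the original session:

import trees

my_dat, labels = trees.CreateDataSet('ID3')
my_tree = trees.CreateTree(my_dat, labels)      # class_list debug prints omitted
# my_tree == {'no surfacing': {0: 0, 1: {'flippers': {0: 0, 1: 1}}}}
print(trees.Classify(my_tree, labels, [1, 1]))  # expected: 1
print(trees.Classify(my_tree, labels, [1, 0]))  # expected: 0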