这两天给小组讲决策树,顺便写了个ID3的代码,贴出来,希望有能用上的,也可以一起讨论学习决策树,后面会给出C4.5...
复制内容到剪贴板
代码:
# -*- coding: utf-8 -*-
"""
This is an ID3 module based on Quinlan's ID3 algorithm. It holds
functions that are responsible for creating a new decision tree
and using the tree to classify the given test samples.
"""
__author__ = "chuter"
__contact__ = "liulong@ir.hit.edu.cn or topgun.chuter@gmail.com"
import sys, math
import pickle
class ID3(object):
    """
    Decision-tree learner based on Quinlan's ID3 algorithm.

    Training data is read from a text file in the following form:
        attrname,attrname,...,targetname|class,class,...
        attrvalue,attrvalue,...,class
        attrvalue,attrvalue,...,class
        ...
    The last name before the '|' is the target attribute; the names after
    the '|' are the possible classes. Blanks around ',' are ignored.

    The data is used to create a decision tree, which can then classify
    test samples given in the following form (one sample per line, values
    in the same order as the training attributes, without the class):
        attrvalue,attrvalue,attrvalue,...
        attrvalue,attrvalue,attrvalue,...

    A tree is either a leaf (the class label) or a dict of the form
    {attr: {attr_value: subtree, ...}}. A trained root dict additionally
    carries the ordered attribute names under the key 'attr_list'.

    The code runs on Python 2.6+ and Python 3.
    """

    # Key under which the ordered attribute names are stored in a model.
    MODEL_ATTR_KEY = 'attr_list'

    def __init__(self):
        self.data = []          # training rows: attribute values + class
        self.attr_list = []     # attribute names, target excluded
        self.attr_index = {}    # attribute name -> column position
        self.target_attr = None  # name of the target attribute
        self.classes = []       # class labels declared in the header

    def _build_attr_index(self):
        """Rebuild the name -> column map from self.attr_list."""
        self.attr_index = {}
        for pos_index, attr in enumerate(self.attr_list):
            self.attr_index[attr] = pos_index

    def err_exit(self, err_massage):
        """
        Report a fatal error on stderr and terminate via SystemExit.

        The exit status is 1 so that callers and shells can tell the
        failure apart from a normal termination.
        """
        sys.stderr.write(
            "Fatal error in file %s: %s\n" % (__file__, err_massage))
        sys.exit(1)

    def most_frequecy_item(self, original_list):
        """
        Return the item that occurs most often in original_list, or None
        when the list is empty.
        """
        counts = {}
        for item in original_list:
            counts[item] = counts.get(item, 0) + 1
        best_count = 0
        best_item = None
        for item in counts:
            if counts[item] > best_count:
                best_count = counts[item]
                best_item = item
        return best_item

    def unique(self, original_list):
        """
        Return the distinct items of original_list as a list, in order of
        first appearance.
        """
        seen = {}
        distinct = []
        for item in original_list:
            if item not in seen:
                seen[item] = 1
                distinct.append(item)
        return distinct

    def load_data(self, datafile):
        """
        Load the training set from datafile into self.data and set
        self.attr_list, self.target_attr and self.classes from its header.

        ',' separates the attrs / attr values; any blanks before or after
        the ',' are stripped. Reading stops at the first line that is
        empty or only one character long. Previously loaded rows are
        discarded, so repeated training does not accumulate samples.

        Terminates through err_exit when the file cannot be opened, when
        the header has no '|', or when a row has the wrong number of
        values.
        """
        try:
            fileobject = open(datafile, 'r')
        except (IOError, OSError) as exception:
            self.err_exit(exception)
        # 'with' closes the file on every path, including err_exit.
        with fileobject:
            header = fileobject.readline().strip()
            pos = header.find('|')
            if pos < 0:
                self.err_exit(
                    "you should give one class at least, or make sure " +
                    "you use '|' to split the attrs and the classes"
                )
            self.attr_list = [
                name.strip() for name in header[:pos].split(',')]
            # Every data row carries all attributes plus the target value.
            attrnum = len(self.attr_list)
            self.target_attr = self.attr_list.pop()
            self.classes = [
                cla.strip() for cla in header[pos + 1:].split(',')]
            self.data = []
            for linenum, raw_line in enumerate(fileobject, 2):
                line = raw_line.strip()
                if len(line) <= 1:
                    break
                attr_value_list = [
                    attr_value.strip() for attr_value in line.split(',')]
                if len(attr_value_list) != attrnum:
                    self.err_exit(
                        "attr value is more or less than the number of " +
                        "attrs in the line \"%s\" of the datafile" % linenum
                    )
                self.data.append(attr_value_list)

    def entropy_cacu(self, data):
        """
        Calculate the entropy (base 2) of the given data set with respect
        to the target attribute, which is the last value of each sample.
        Returns 0.0 for an empty data set.
        """
        total = float(len(data))
        # Frequency of each value of the target attribute.
        freq_count = {}
        for sample in data:
            label = sample[-1]
            freq_count[label] = freq_count.get(label, 0.0) + 1.0
        data_entropy = 0.0
        for freq in freq_count.values():
            prob = freq / total
            data_entropy -= prob * math.log(prob, 2)
        return data_entropy

    def gain_cacu(self, data, attr):
        """
        Calculate the information gain that would result from splitting
        the data on the chosen attribute.
        """
        pos_index = self.attr_index[attr]
        total = float(len(data))
        # Frequency of each value of the chosen attribute.
        freq_count = {}
        for sample in data:
            attr_val = sample[pos_index]
            freq_count[attr_val] = freq_count.get(attr_val, 0.0) + 1.0
        # Sum of the entropy of each subset of records, weighted by the
        # probability of that subset occurring in the data.
        subset_entropy = 0.0
        for attr_val in freq_count:
            val_prob = freq_count[attr_val] / total
            data_subset = [
                sample for sample in data if sample[pos_index] == attr_val]
            subset_entropy += val_prob * self.entropy_cacu(data_subset)
        return self.entropy_cacu(data) - subset_entropy

    def choose_attribute(self, data, attr_list):
        """
        Search through all the attributes and return the one with the
        highest information gain. On ties the later attribute wins.
        Returns None only when attr_list is empty.
        """
        # Start from "no gain seen yet" rather than 0.0: rounding can make
        # a gain marginally negative, and an attribute must still be picked.
        best_gain = None
        best_attr = None
        for attr in attr_list:
            gain = self.gain_cacu(data, attr)
            if best_gain is None or gain >= best_gain:
                best_gain = gain
                best_attr = attr
        return best_attr

    def get_examples(self, data, attr, val):
        """
        Return the list of samples whose value of attr equals val.
        """
        if not data:
            return []
        pos_index = self.attr_index[attr]
        return [sample for sample in data if sample[pos_index] == val]

    def create_decision_tree(self, data, attr_list):
        """
        Recursively create the decision tree for data using the attributes
        in attr_list. self.attr_index must already map every attribute to
        its column.

        Returns a class label (leaf) or a dict {attr: {value: subtree}}.
        """
        tar_val_list = [sample[-1] for sample in data]
        # Stop conditions: nothing left to split on, or a pure node.
        if not data or len(attr_list) <= 0:
            return self.most_frequecy_item(tar_val_list)
        if tar_val_list.count(tar_val_list[0]) == len(tar_val_list):
            return tar_val_list[0]
        best = self.choose_attribute(data, attr_list)
        tree = {best: {}}
        pos_index = self.attr_index[best]
        remaining = [attr for attr in attr_list if attr != best]
        # Create a sub decision tree for each observed value of best.
        for attr_val in self.unique([sample[pos_index] for sample in data]):
            tree[best][attr_val] = self.create_decision_tree(
                self.get_examples(data, best, attr_val), remaining)
        return tree

    def train_model(self, datafile=None, modelname=None):
        """
        Use the ID3 algorithm to train a model and save it with pickle.

        datafile defaults to 'data' and modelname to 'model'. Returns the
        tree; when the tree is a dict, the ordered attribute names are
        stored in it under 'attr_list' so that the model is
        self-describing. A tree that is a single leaf (e.g. all samples
        share one class) is saved and returned as the bare label.
        """
        if not datafile:
            datafile = 'data'
        self.load_data(datafile)
        self._build_attr_index()
        tree = self.create_decision_tree(self.data, self.attr_list)
        if not modelname:
            modelname = 'model'
        if isinstance(tree, dict):
            tree[self.MODEL_ATTR_KEY] = self.attr_list
        try:
            with open(modelname, 'wb') as model_fob:
                pickle.dump(tree, model_fob)
        except (IOError, OSError, pickle.PicklingError) as error_info:
            self.err_exit(error_info)
        return tree

    def get_classification(self, sample, tree):
        """
        Recursively trace the decision tree and return the classification
        for the given sample.

        Terminates through err_exit when the sample holds a value the tree
        has no branch for, or is too short.
        """
        # Anything that is not a dict is a leaf: return the node value.
        if not isinstance(tree, dict):
            return tree
        # Otherwise follow the branch that matches the sample's value.
        attr = next(iter(tree))
        try:
            branch = tree[attr][sample[self.attr_index[attr]]]
        except (KeyError, IndexError):
            self.err_exit("The test data is error, check the attrs!!")
        return self.get_classification(sample, branch)

    def classify(self, tree, test_data):
        """
        Classify every sample of test_data with tree and return the list
        of classifications.

        When tree carries an 'attr_list' entry, that attribute order is
        used and the entry is left out of the traversal; the caller's
        tree is not modified. Otherwise the attribute order already held
        by this instance is used, and err_exit is called when there is
        none.
        """
        if isinstance(tree, dict):
            if self.MODEL_ATTR_KEY in tree:
                self.attr_list = list(tree[self.MODEL_ATTR_KEY])
                # Work on a copy so the metadata key can never be taken
                # for the root attribute and the model stays reusable.
                tree = dict(
                    (key, value) for key, value in tree.items()
                    if key != self.MODEL_ATTR_KEY)
            elif not self.attr_list:
                self.err_exit("Make sure give the right path of the model")
        self._build_attr_index()
        return [
            self.get_classification(sample, tree) for sample in test_data]

    def load_test_samples(self, filename):
        """
        Read the test samples from filename: one comma-separated sample
        per line, with blanks around the values stripped. Reading stops
        at the first line that is empty or only one character long.
        """
        try:
            fobject = open(filename, 'r')
        except (IOError, OSError) as error_info:
            self.err_exit(error_info)
        samples = []
        with fobject:
            for raw_line in fobject:
                line = raw_line.strip()
                if len(line) <= 1:
                    break
                samples.append([attr.strip() for attr in line.split(',')])
        return samples

    def test(self, testfile=None, modelname=None, tree=None):
        """
        Use a trained model (tree) to classify the samples in testfile.

        When modelname is given the model is loaded from that file;
        otherwise the supplied tree is used. Terminates through err_exit
        when testfile is missing or when neither modelname nor tree is
        given.
        """
        if not testfile:
            self.err_exit("You should give the path of the test file")
        if modelname:
            # NOTE: pickle can execute arbitrary code while loading; only
            # load model files that come from a trusted source.
            try:
                with open(modelname, 'rb') as model_fob:
                    tree = pickle.load(model_fob)
            except Exception as error_info:
                self.err_exit(error_info)
        elif not tree:
            self.err_exit("You should give the model")
        test_samples = self.load_test_samples(testfile)
        return self.classify(tree, test_samples)
# Original poster's note: the forum code box seems a bit too narrow, so the
# listing may not display well...
[本帖最后由 chuter 于 2008-6-5 17:46 编辑]