发新话题
打印

[算法学习] ID3代码

ID3代码

这两天给小组讲决策树,顺便写了个ID3的代码,贴出来,希望有能用上的,也可以一起讨论学习决策树,后面会给出C4.5...
复制内容到剪贴板
代码:
# -*- coding: utf-8 -*-

"""
This is a ID3 module based on Quinlan's ID3 algorithm,that holds
functions that are responsible for creating a new decision tree
and using the tree to classfy the given test samples.
"""

__author__ = "chuter"
__contact__ = "liulong@ir.hit.edu.cn or topgun.chuter@gmail.com"

import sys, math
import pickle

class ID3:
   
    """
    based on Quinlan's ID3 algorithm,load data file in the form as follows:
        
    attrname,attrname,attrname,attrname,attrname...,classes|class,class,...
    attrvalue, attrvalue,attrvalue,...,class
    attrvalue, attrvalue,attrvalue,...,class
    ...
   
    then use the data to create a dicition tree, and can classfy the given
    test samples use the tree, the sample in the form as follows:
        
    attrvalue,attrvalue,attrvalue,attrvalue...
    attrvalue,attrvalue,attrvalue,attrvalue...
   
    """
    def __init__(self):
        self.data       = []
        self.attr_list  = []
        self.attr_index = {}
        
    def err_exit(self, err_massage):
        print >>sys.stderr, "Fatal error in file %s: %s"%(__file__, err_massage)
        sys.exit()
   
    def most_frequecy_item(self, original_list):
        number_dic = {}
        for item in original_list:
            if number_dic.has_key(item):
                number_dic[item] += 1
            else:
                number_dic[item] = 1
        _max    = 0
        ret_val = None
        for key in number_dic.keys():
            if number_dic[key] > _max:
                _max    = number_dic[key]
                ret_val = key
        return ret_val
        
    def unique(self, original_list):
        unique_dic = {}
        for item in original_list:
            unique_dic[item] = 1
        return unique_dic.keys()
        
    def load_data(self, datafile):
        #use ',' to split each attr or attrvalue, but there can be any blanks
        #before or after the ','
        try:
            fileobject = open(datafile, 'r')
        except Exception, exception:
            self.err_exit(exception)
        line = fileobject.readline().strip()
        pos  = line.find('|', 0)
        if pos < 0:
            fileobject.close()
            self.err_exit(
                "you should give one class at least, or make sure" +
                "you use '|' to split the attrs and the classes"
                )
        attr             = line[:pos]
        classes          = line[pos+1:]
        self.attr_list   = [attr.strip() for attr in attr.split(',')]
        attrnum          = len(self.attr_list)
        self.target_attr = self.attr_list.pop()
        self.classes     = [cla.strip() for cla in classes.split(',')]
        linenum          = 2
        while True:
            line = fileobject.readline().strip()
            if len(line) <= 1:
                break
            attr_value_list = [attr_value.strip() for attr_value in line.split(',')]
            if len(attr_value_list) != attrnum:
                error_info = "attr value is more or less than the number of" +\
                              "attrs in the line \"%s\" of the datafile"%linenum
                self.err_exit(error_info)
            self.data.append(attr_value_list)
            linenum += 1
        fileobject.close()
   
    def entropy_cacu(self, data):
        """
        Calculates the entropy of the given data set for the target attribute
        """
        freq_count = {}
        data_entropy = 0.0

        # Calculate the frequency of each of the values of the target attr
        for sample in data:
            attr_val = sample[-1]
            if (freq_count.has_key(attr_val)):
                freq_count[attr_val] += 1.0
            else:
                freq_count[attr_val] = 1.0

        # Calculate the entropy of the data for the target attribute
        for freq in freq_count.values():
            data_entropy += (-freq/len(data)) * math.log(freq/len(data), 2)
            
        return data_entropy
        
    def gain_cacu(self, data, attr):
        """
        Calculates the information gain that would result by splitting
        the data on the chosen attribute.
        """
        freq_count     = {}
        pos_index      = self.attr_index[attr]        

        # Calculate the frequency of each of the values of the target attribute
        for sample in data:
            attr_val = sample[pos_index]
            if (freq_count.has_key(attr_val)):
                freq_count[attr_val] += 1.0
            else:
                freq_count[attr_val] = 1.0

        # Calculate the sum of the entropy for each subset of records weighted
        # by their probability of occuring in the training set.
        subset_entropy = 0.0
        for attr_val in freq_count.keys():
            val_prob = freq_count[attr_val] / sum(freq_count.values())
            data_subset = [sample for sample in data if sample[pos_index] == attr_val]
            subset_entropy += val_prob * self.entropy_cacu(data_subset)

        return (self.entropy_cacu(data) - subset_entropy)

    def choose_attribute(self, data, attr_list):
        """
        search through all the attributes and returns the attribute with the
        highest information gain
        """
        best_gain = 0.0
        best_attr = None

        for attr in attr_list:
            gain = self.gain_cacu(data, attr)
            if gain >= best_gain:
                best_gain = gain
                best_attr = attr
                    
        return best_attr

    def get_examples(self, data, attr, val):
        """
        Returns a list of samples with the value of attr
        matching the given val.
        """
        
        if not data:
            return []
        else:
            pos_index = self.attr_index[attr]
            return [sample for sample in data if sample[pos_index] == val]
                        
    def create_decision_tree(self, data, attr_list):
        """
        create the dicition tree based on the data have loaded Recursivly
        """
        
        tar_val_list = [sample[-1] for sample in data]
        #the Conditions to return
        if not data or len(attr_list) <= 0:            
            return self.most_frequecy_item(tar_val_list)
        elif tar_val_list.count(tar_val_list[0]) == len(tar_val_list):
            return tar_val_list[0]
        else:
            best      = self.choose_attribute(data, attr_list)
            tree      = {best:{}}
            pos_index = self.attr_index[best]
            # Create a new sub decision tree for each of the values of the best
            for attr_val in self.unique([sample[pos_index] for sample in data]):
                subtree = self.create_decision_tree(
                        self.get_examples(data, best, attr_val), \
                                [attr for attr in attr_list if attr != best])         
                tree[best][attr_val] = subtree
            
        return tree
   
    def train_model(self, datafile=None, modelname=None):
        """
        use ID3 algorithm to train a model and save the model use pickle
        """
        if not datafile:
            datafile = 'data'
        self.load_data(datafile)
        self.attr_index = {}
        pos_index = 0
        for attr in self.attr_list:
            self.attr_index[attr] = pos_index
            pos_index += 1
        
        tree  = self.create_decision_tree(self.data, self.attr_list)
        if not modelname:
            modelname = 'model'
        try:
            model_fob = open(modelname, 'wb')
            tree['attr_list'] = self.attr_list
            pickle.dump(tree, model_fob)
            model_fob.close()
        except Exception, error_info:
            self.err_exit(error_info)
            
        return tree        

    def get_classification(self, sample, tree):
        """
        This function recursively trace the decision tree and returns a
        classification for the given sample.
        """
        # If we reach a leaf node, we return the node value
        if type(tree) == str:
            return tree        
        # else we trace the tree deeper until a leaf node is found.
        else:
            attr = tree.keys()[0]            
            try:
                pos_index = self.attr_index[attr]
                t = tree[attr][sample[pos_index]]
            except KeyError:
                self.err_exit("The test data is error, check the attrs!!")
            return self.get_classification(sample, t)
        
    def classify(self, tree, test_data):
        if not self.attr_list:
            try:               
                self.attr_list = tree.pop('attr_list')
            except KeyError:
                self.err_exit("Make sure give the right path of the model")
        pos_index = 0
        for attr in self.attr_list:
            self.attr_index[attr] = pos_index
            pos_index += 1
            
        classification_list = []            
        for sample in test_data:
            classification_list.append(self.get_classification(sample, tree))
        
        return classification_list
   
    def load_test_samples(self, filename):
        try:
            fobject = open(filename, 'rb')
        except Exception, error_info:
            self.err_exit(error_info)
            
        samples = []
        while True:
            line = fobject.readline().strip()
            if len(line) <= 1:
                break
            samples.append([attr.strip() for attr in line.split(',')])
            
        return samples
        
    def test(self, testfile=None, modelname=None, tree=None):
        """
        use the model(tree) has trained to do the test for the test samples
        """
        if not testfile:
            self.err_exit("You should give the path of the test file")
        if not modelname:
            if not tree:
                self.err_exit("You should give the model")
            else:
                modelname = 'model'
        try:
            model_fob = open(modelname, 'rb')
            tree = pickle.load(model_fob)
            model_fob.close()
        except Exception, error_info:
            self.err_exit(error_info)
        test_samples = self.load_test_samples(testfile)
        return self.classify(tree, test_samples)
显示上,代码宽度设置好像小了点儿,显示不太好...

[ 本帖最后由 chuter 于 2008-6-5 17:46 编辑 ]
#-----------------------------
Think big, think difference and do your best!

TOP

发新话题