• 代码示例

    代码示例

    让我们使用加仑公里数这个数据集,格式如下:

    代码示例 - 图1

    我会通过汽车的以下属性来判断它的加仑公里数:汽缸数、排气量、马力、重量、加速度。我将392条数据都存放在mpgData.txt文件中,并用下面这段Python代码将这些数据按层次等分成十份:

    1. # -*- coding: utf-8 -*-
    2. #
    3. # 将数据等分成十份的示例代码
    4. import random
    5. def buckets(filename, bucketName, separator, classColumn):
    6. """filename是源文件名
    7. bucketName是十个目标文件的前缀名
    8. separator是分隔符,如制表符、逗号等
    9. classColumn是表示数据所属分类的那一列的序号"""
    10. # 将数据分为10份
    11. numberOfBuckets = 10
    12. data = {}
    13. # 读取数据,并按分类放置
    14. with open(filename) as f:
    15. lines = f.readlines()
    16. for line in lines:
    17. if separator != '\t':
    18. line = line.replace(separator, '\t')
    19. # 获取分类
    20. category = line.split()[classColumn]
    21. data.setdefault(category, [])
    22. data[category].append(line)
    23. # 初始化分桶
    24. buckets = []
    25. for i in range(numberOfBuckets):
    26. buckets.append([])
    27. # 将各个类别的数据均匀地放置到桶中
    28. for k in data.keys():
    29. # 打乱分类顺序
    30. random.shuffle(data[k])
    31. bNum = 0
    32. # 分桶
    33. for item in data[k]:
    34. buckets[bNum].append(item)
    35. bNum = (bNum + 1) % numberOfBuckets
    36. # 写入文件
    37. for bNum in range(numberOfBuckets):
    38. f = open("%s-%02i" % (bucketName, bNum + 1), 'w')
    39. for item in buckets[bNum]:
    40. f.write(item)
    41. f.close()
    42. # 调用示例
    43. buckets("pimaSmall.txt", 'pimaSmall',',',8)

    执行这个程序后会生成10个文件:mpgData01、mpgData02等。

    编程实践

    你能否修改上一章的近邻算法程序,让test函数能够执行十折交叉验证?输出的结果应该是这样的:

    代码示例 - 图2

    解决方案

    我们需要进行以下几步:

    • 修改初始化方法,只读取九个桶中的数据作为训练集;
    • 增加一个方法,从第十个桶中读取测试集;
    • 执行十折交叉验证。

    下面我们分步来看:

    • 初始化方法__init__

    __init__方法的签名会修改成以下形式:

    1. def __init__(self, bucketPrefix, testBucketNumber, dataFormat):

    每个桶的文件名是mpgData-01、mpgData-02这样的形式,所以bucketPrefix就是“mpgData”。testBucketNumber是测试集所用的桶,如果是3,则分类器会使用1、2、4-9的桶进行训练。dataFormat用来指定数据集的格式,如:

    1. class num num num num num comment

    意味着第一列是所属分类,后五列是特征值,最后一列是备注信息。

    以下是初始化方法的示例代码:

    1. class Classifier:
    2. def __init__(self, bucketPrefix, testBucketNumber, dataFormat):
    3. """该分类器程序将从bucketPrefix指定的一系列文件中读取数据,
    4. 并留出testBucketNumber指定的桶来做测试集,其余的做训练集。
    5. dataFormat用来表示数据的格式,如:
    6. "class num num num num num comment"
    7. """
    8. self.medianAndDeviation = []
    9. # 从文件中读取文件
    10. self.format = dataFormat.strip().split('\t')
    11. self.data = []
    12. # 用1-10来标记桶
    13. for i in range(1, 11):
    14. # 判断该桶是否包含在训练集中
    15. if i != testBucketNumber:
    16. filename = "%s-%02i" % (bucketPrefix, i)
    17. f = open(filename)
    18. lines = f.readlines()
    19. f.close()
    20. for line in lines[1:]:
    21. fields = line.strip().split('\t')
    22. ignore = []
    23. vector = []
    24. for i in range(len(fields)):
    25. if self.format[i] == 'num':
    26. vector.append(float(fields[i]))
    27. elif self.format[i] == 'comment':
    28. ignore.append(fields[i])
    29. elif self.format[i] == 'class':
    30. classification = fields[i]
    31. self.data.append((classification, vector, ignore))
    32. self.rawData = list(self.data)
    33. # 获取特征向量的长度
    34. self.vlen = len(self.data[0][1])
    35. # 标准化数据
    36. for i in range(self.vlen):
    37. self.normalizeColumn(i)
    • testBucket方法

    下面的方法会使用一个桶的数据进行测试:

    1. def testBucket(self, bucketPrefix, bucketNumber):
    2. """读取bucketPrefix-bucketNumber所指定的文件作为测试集"""
    3. filename = "%s-%02i" % (bucketPrefix, bucketNumber)
    4. f = open(filename)
    5. lines = f.readlines()
    6. totals = {}
    7. f.close()
    8. for line in lines:
    9. data = line.strip().split('\t')
    10. vector = []
    11. classInColumn = -1
    12. for i in range(len(self.format)):
    13. if self.format[i] == 'num':
    14. vector.append(float(data[i]))
    15. elif self.format[i] == 'class':
    16. classInColumn = i
    17. theRealClass = data[classInColumn]
    18. classifiedAs = self.classify(vector)
    19. totals.setdefault(theRealClass, {})
    20. totals[theRealClass].setdefault(classifiedAs, 0)
    21. totals[theRealClass][classifiedAs] += 1
    22. return totals

    比如说bucketPrefix是mpgData,bucketNumber是3,那么程序会从mpgData-03中读取内容,作为测试集。这个方法会返回如下形式的结果:

    1. {'35': {'35': 1, '20': 1, '30': 1},
    2. '40': {'30': 1},
    3. '30': {'35': 3, '30': 1, '45': 1, '25': 1},
    4. '15': {'20': 3, '15': 4, '10': 1},
    5. '10': {'15': 1},
    6. '20': {'15': 2, '20': 4, '30': 2, '25': 1},
    7. '25': {'30': 5, '25': 3}}

    这个字段的键表示真实类别。如第一行的35表示该行数据的真实类别是35加仑公里。这个键又对应一个字典,这个字典表示的是分类器所判断的类别,如:

    1. '15': {'20': 3, '15': 4, '10': 1},

    其中的3表示有3条记录真实类别是15加仑公里,但被分类到了20加仑公里;4表示分类正确的记录数;1表示被分到10加仑公里的记录数。

    • 执行十折交叉验证

    最后我们需要编写一段程序来执行十折交叉验证,也就是说要用不同的训练集和测试集来构建10个分类器。

    1. def tenfold(bucketPrefix, dataFormat):
    2. results = {}
    3. for i in range(1, 11):
    4. c = Classifier(bucketPrefix, i, dataFormat)
    5. t = c.testBucket(bucketPrefix, i)
    6. for (key, value) in t.items():
    7. results.setdefault(key, {})
    8. for (ckey, cvalue) in value.items():
    9. results[key].setdefault(ckey, 0)
    10. results[key][ckey] += cvalue
    11. # 输出结果
    12. categories = list(results.keys())
    13. categories.sort()
    14. print( "\n Classified as: ")
    15. header = " "
    16. subheader = " +"
    17. for category in categories:
    18. header += category + " "
    19. subheader += "----+"
    20. print (header)
    21. print (subheader)
    22. total = 0.0
    23. correct = 0.0
    24. for category in categories:
    25. row = category + " |"
    26. for c2 in categories:
    27. if c2 in results[category]:
    28. count = results[category][c2]
    29. else:
    30. count = 0
    31. row += " %2i |" % count
    32. total += count
    33. if c2 == category:
    34. correct += count
    35. print(row)
    36. print(subheader)
    37. print("\n%5.3f percent correct" %((correct * 100) / total))
    38. print("total of %i instances" % total)
    39. # 调用方法
    40. tenfold("mpgData/mpgData", "class num num num num num comment")

    执行结果如下:

    1. Classified as:
    2. 10 15 20 25 30 35 40 45
    3. +----+----+----+----+----+----+----+----+
    4. 10 | 3 | 10 | 0 | 0 | 0 | 0 | 0 | 0 |
    5. 15 | 3 | 68 | 14 | 1 | 0 | 0 | 0 | 0 |
    6. 20 | 0 | 14 | 66 | 9 | 5 | 1 | 1 | 0 |
    7. 25 | 0 | 1 | 14 | 35 | 21 | 6 | 1 | 1 |
    8. 30 | 0 | 1 | 3 | 17 | 21 | 14 | 5 | 2 |
    9. 35 | 0 | 0 | 2 | 8 | 9 | 14 | 4 | 1 |
    10. 40 | 0 | 0 | 1 | 0 | 5 | 5 | 0 | 0 |
    11. 45 | 0 | 0 | 0 | 2 | 1 | 1 | 0 | 2 |
    12. +----+----+----+----+----+----+----+----+
    13. 53.316 percent correct
    14. total of 392 instances

    可以在这里下载代码和数据集。