Apriori是最常见的关联分析算法之一,其基本步骤是:
(1)令 k=1,生成所有长度为1的频繁集
重复下列步骤,直到不能确定新的频繁集
(2)根据长度为k的频繁集生成长度为k+1的频繁集
(3)修剪掉存在k长度的子集不是频繁集的候选集
(4)扫描所有事务计算每个候选集的支持度
(5)排除不频繁的候选集,仅保留频繁的

import argparse
from itertools import chain, combinations

class Rule(object):
    '''ASSOCIATION RULES'''
    '''一个规则类,成员包含规则前件A,后件B,规则的支持度,置信度,timie是规则的序号,无意义'''
    def __init__(self, A, B, support, confidence, time):
        self.A = A
        self.B = B
        self.support = support
        self.confidence = confidence
        self.time = time

    def __repr__(self):
        return '%s ==> %-6s\t%.3f\t\t%.3f' % (' '.join(sorted(list(self.A))),
                                        ' '.join(sorted(list(self.B))),
                                        self.confidence,
                                        self.support)

class Apriori(object):
    '''APRIORI'''
    '''初始化,设定数据集,支持度阈值,置信度阈值,同时计算出频繁集'''
    def __init__(self, data, min_support, min_confidence):
        self.data = data
        self.min_support = min_support
        self.min_confidence = min_confidence
        self.itemset, self.transaction_list = self.get_itemset_from_data()
        self.frequent_itemset = self.get_frequent_itemset()

    @staticmethod
    def join_set(itemset, k):
        '''JOINS TWO ITEMSETS TO GET A k LENGTH UNION'''
        '''使用两个k-1阶频繁子集生成k阶候选集,这里简单使用了union方法,效率低,有重复,可优化!!!!!'''
        return set([i.union(j) for i in itemset for j in itemset if len(i.union(j)) == k])

    @staticmethod
    def get_combined_subsets(itemset):
        '''COMBINES ITEMSETS'''
        ''' 这里用于生成频繁集所有可能的子集,构造规则的时候使用,这里也可以剪枝优化!!!!!! '''
        return chain(*[combinations(itemset, index + 1) for index, item in enumerate(itemset)])

    def get_itemset_from_data(self):
        '''EXTRACTS ITEMSET FROM DATABASE'''
        ''' 这里使用set找出了所有的可能项(顺便也构造了1-项集),同时构造除了事务集列表'''
        itemset = set()
        transaction_list = list()
        for row in self.data:
            transaction_list.append(frozenset(row))
            for item in row:
                if item:
                    itemset.add(frozenset([item]))
        return itemset, transaction_list

    def get_support_list(self):
        '''GENERATES SUPPORT LIST HIGHER THAN MINIMUM SUPPORT THRESHOLD'''
        ''' 这里计算当前k-项集的支持度,并使用支持度阈值进行筛选,判断一个事务是否支持项集直接使用了issubset方法,这里也可以进行优化!!!!!'''
        unpruned_list = [(item, float(sum(1 for row in self.transaction_list if item.issubset(row)))/len(self.transaction_list))
                         for item in self.itemset]
        return dict([(item, support) for item, support in unpruned_list if support >= self.min_support])

    
    def get_frequent_itemset(self):
        '''GENERATES FREQUENT ITEMSETS'''
        ''' 计算所有频繁项集 '''
        ''' 这种处理方式没有通过字典序要求合并k-1阶项集来避免生成重复的k阶项集,也没有使用k-1阶频繁集筛选k阶频繁集!!!!!'''
        frequent_itemset = dict()
        k = 1
        while True:
            ''' 当k=1时,直接用get_itemset_from_data中构造的1-项集计算支持度同时用阈值筛选 '''
            ''' 当k>1时,用上次获得的k-1-项集组合生成k-项集,再计算其支持度并用阈值筛选'''
            if k > 1:
                self.itemset = self.join_set(next_itemset, k)
            next_itemset = self.get_support_list()
            ''' 如果所有的候选都被阈值筛选掉了,结束 '''
            if not next_itemset:
                break
            ''' 记录找到的频繁集,前面k-1阶合并得到k阶进行了所有可能的合并,所有有重复,这里使用update方法去掉重复 '''            
            frequent_itemset.update(next_itemset)
            k += 1
        return frequent_itemset

    def run(self):
        '''RUNS APRIORI ALGORITHM'''
        ''' 针对每个频繁集,构建可能规则集并计算置信度进行筛选 '''
        ''' 这里构造了所有可能的规则进行置信度筛选,没有使用反单调性进行剪枝,可优化!!!!!!!!'''
        rules, time = list(), 0
        for item, support in self.frequent_itemset.items(): #对所有的频繁集进行美剧
            if len(item) > 1:
                for A in self.get_combined_subsets(item): #对所有可能子集进行枚举
                    B = item.difference(A) #后件
                    if B:
                        A = frozenset(A)
                        AB = A | B
                        confidence = float(self.frequent_itemset[AB]) / self.frequent_itemset[A]
                        if confidence >= self.min_confidence:
                            rules.append(Rule(A, B, support=self.frequent_itemset[AB], confidence=confidence, time=time))
                            time += 1
        return rules, self.frequent_itemset

def parse_arguments():
    '''PARSES COMMAND LINE ARGUMENTS'''
    argparser = argparse.ArgumentParser(description='Apriori Algorithm.')
    argparser.add_argument(
        '-s', '--min_support',
        dest='min_support',
        help='minimum support',
        default=0.25,
        type=float
    )
    argparser.add_argument(
        '-c', '--min_confidence',
        dest='min_confidence',
        help='minimum confidence',
        default=0.5,
        type=float
    )
    argparser.add_argument(
        dest='filename',
        help='filename containing transactions',
        default='transactions.txt',
    )
    return argparser.parse_args()

def data_from_txt(filename):
    '''EXTRACTS DATABASE FROM .txt FILE'''
    file = open(filename, 'r')
    for line in file:
        row = line.strip().split()        
        yield row

def print_frequent_itemsets(itemset):
    '''PRINTS FREQUENT ITEMSETS'''
    print('========================')
    print('Itemset\t\tSupport')
    print('========================')
    for item in itemset.keys():
        print('%s\t\t%.3f' % (' '.join(sorted(list(item))), itemset[item]))

def print_association_rules(rules):
    '''PRINTS ASSOCIATION RULES'''
    print('========================================')
    print('    Rule\tConfidence\tSupport')
    print('========================================')
    rules.sort(key=lambda x: (len(x.A) + len(x.B), x.confidence, x.support, -x.time), reverse=True)
    for rule in rules:
        print(rule)

def main():
    '''MAIN METHOD'''
    min_support = 0.2
    min_confidence = 0.6
    data = data_from_txt('./transactions.txt')    
    rules, itemset = Apriori(data, min_support, min_confidence).run()
    print('Mined {}\nand found a total of {} association rules:'.format('transactions', len(rules)))
    print_association_rules(rules)
    print_frequent_itemsets(itemset)

if __name__ == '__main__':
    main()