大家好,我是W

在数据挖掘中有一种关联分析算法叫做Apriori算法,大家可能都听说过啤酒尿布的故事,购买尿布的爸爸很可能会再去购买一份啤酒来犒劳自己,在大数据的背景下已经无法使用人工的方法去发现海量商品间的关联性,所以需要算法的支持。Apriori就是关联性分析算法的祖师级算法。接下来我们从下面几个内容来讲Apriori算法:1、相関概念 2、算法原理 3、Apriori算法实现 - 7500行购物清单案例 4、算法优劣分析。

1、 相关概念

在学习算法前需要了解一些特定名词,以及一些评估频繁项集的几个指标。

1.1 事务型数据

关联分析的数据一般是事务型数据,即数据每一行对应一个事务,事务的元素称为项,项集就是项组合成的集合,有1-项集、2-项集、k-项集。

1.2 支持度

一个项集的支持度就是该项集在整个数据集中出现的频率,例如:

数据集:
2,3,5,6
1,3,5
2,4
5,6,7

1-项集:3在数据集中的支持度为 2/4

具体的公式为:

apriori算法 机器学习 apriori算法应用_算法

1.3 置信度

置信度是指一个数据出现的情况下,某一数据出现的概率。或者说在购买A商品的情况下,继续去购买B商品的概率。其公式是:

apriori算法 机器学习 apriori算法应用_python_02

A事件发生的情况下,B事件也发生的概率。

1.4 提升度

提升度是指在A事件的带动下B事件发生的概率与无A事件B事件发生的概率的比值,说人话就是有A和没A对B的影响比值。提升度就等于置信度比上B事件发生的概率。其公式是:

apriori算法 机器学习 apriori算法应用_算法_03

2、算法原理

Apriori算法的目的是通过刚刚提到的几个指标去找到物品之间的规则,如果使用最笨的方法,那么就是遍历出所有的组合,然后通过公式计算出支持度,通过支持度去计算置信度,通过置信度计算提升度。当置信度大于一定阈值的时候则可以认为两件物品的出现具有一定规则,并且物品A对物品B的提升度也可以计算。

但是我来给大家画张图看看这种做法的成本有多大,假设有物品1~5,用圆圈表示:

apriori算法 机器学习 apriori算法应用_数据分析_04

大家可以看到,只有五个商品就可以织出那么恐怖的网络,所以遍历的方法是不可取的。

2.1 非频繁项集的超集也是非频繁项集

因为并不是所有的物品都是频繁的,所以我们需要指定一个最小支持度来筛选出频繁的物品,并且依靠标题这一句**“非频繁项集的超集也是非频繁项集”**,我们可以实现剪枝操作。

为什么非频繁项集的超集也是频繁项集,我们感性的去认识就可以了:连中两张500W的彩票已经够稀有了,如果再中一张那是不是更稀有,那就更不可能频繁了。

具体是怎么做的呢?在计算不同项集的支持度的时候需要我们添加判断,只有支持度大于最小支持度的物品才能进入下一轮的两两组合,那么非频繁项集就不会进入下一轮也就不需要计算其超集。这能帮助我们大大减少计算压力。

为了让大家好理解,我再来一张图给大家感受一下:

apriori算法 机器学习 apriori算法应用_算法_05

可以看到35这个2-项集是非频繁的,那么它的所有超集,以及超集的超集都不会是频繁的,那么在实际应用中就会少去很多项集需要计算组合等等。

3、 Apriori算法实现 - 7500行购物清单案例

接下来我们使用一个7500行的数据集来做Apriori算法的实现。

3.1 观察数据集

虾,杏仁,鳄梨,混合蔬菜,绿葡萄,全麦面粉,山药,农家干酪,功能饮料,番茄汁,低脂酸奶,绿茶,蜂蜜,沙拉,矿泉水,三文鱼,抗氧化剂果汁,冷冻果汁,菠菜,橄榄油
汉堡,肉丸,鸡蛋
酸辣酱
火鸡,鳄梨
矿泉水,牛奶,能量条,全麦大米,绿茶
低脂酸奶
全麦意大利面,炸薯条
汤,淡奶油,青葱
冷冻蔬菜,意大利细面条,绿茶
炸薯条
鸡蛋,宠物食品
饼干
火鸡,汉堡,矿泉水,鸡蛋,食用油
意大利细面条,香槟,饼干
矿泉水,三文鱼
矿泉水
虾,巧克力,鸡肉,蜂蜜,油,食用油,低脂酸奶
.....

数据放在txt文件内,数据用逗号隔开,每一个事务的长度都不同,并且数据的每一个元素都是中文。所以我们不建议使用pandas来处理,并且为了方便计算,可能需要把字符串转为数字来处理。(因为在下面的比较中数字比较方便)

3.2 读取数据

使用正常的方式读取文件,并且按逗号分隔,装进一个二维列表里:

def load_data(file_path):
    """
    加载数据集
    :param file_path: 文件路径
    :return:  data_list list类型
    """

    """
    若使用pandas来读的话会出现许多NaN,所以就直接读取文件
    """

    data_list = []

    with open(file_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip("\n")
            data_list.append(line.split(','))
    # print(data_list)
    return data_list

3.3 数据转换

我们需要把字符串转为数字,并且其中包括去重:

def data_2_index(data_set):
    """
    把data_set中的字符串转为index

    :param data_set: 数据列表 list
    :return: data_set 数据列表 list
    """
    """
        for i in itertools.chain(*data_set):
        print(i)
        这一句可以看itertools.chain的效果        
    """
    # 把data_set拆包 然后取出其中的元素 通过set去重
    items = set(itertools.chain(*data_set))
    # print(items)
    # 保存字符串到编号的映射
    str_2_index = {}
    # 保存编号到字符串的映射
    index_2_str = {}
    for index, item in enumerate(items):
        # print(index, '->', item)
        str_2_index[item] = index
        index_2_str[index] = item

    for i in range(len(data_set)):
        for j in range(len(data_set[i])):
            data_set[i][j] = str_2_index[data_set[i][j]]

    # print(str_2_index)
    # print(index_2_str)
    # print(data_set)
    return data_set, index_2_str

3.4 计算频繁项集

前面完成了数据的处理,到这一步我们涉及到频繁项集的各种容器,在这我们需要约定从而方便代码阅读。约定C_k为k-候选项集,L_k为频繁项集。至于为什么需要候选项集,请看下图:

apriori算法 机器学习 apriori算法应用_python_06

候选项集:即这个集合没有经过支持度的计算,尚未知道其支持度从而不确定是否是频繁项集,通常是频繁项集间两两组合后生成(1-候选项集除外)

频繁项集:即经过支持度计算并通过最小支持度的筛选的项集。

通过循环判断频繁项集是否为空集可以知道接下来是否还有新的组合。那么**“非频繁项集的超集也是非频繁项集”这句话是如何实现的呢?**其实前面有提示,我们得到频繁项集后经过两两项集组合得到候选的k+1项集,然后有需要经过计算支持度查看是否满足最小支持度。若该组合k+1项集不满足最小支持度,即该项集为非频繁,那么就不应该进入频繁k+1项集中,也就不会进入k+2候选项集中。所以通过候选项集可以排除掉非频繁的项集。

3.4.1 首先计算候选1-项集

def build_c1(data_set):
    """
    创建候选1项集
    :param data_set: 数字化后的data_set
    :return:
    """
    # 把data_set中的元素去重
    items = set(itertools.chain(*data_set))
    # print(items)
    # 用frozenset把项集装进新列表里
    """
    Tips: 使用frozenset的原意是接下来的步骤需要使用items里的内容做key
    若直接将数字作为key的话也可以,但是后面还有生成二项集、三项集的操作,那就需要用list等来装,这样就不能作为key了
    
    即:
        my_dict = {}
        my_dict[frozenset([1, 2, 3])] = 2.2
        这个操作时可以的,打印my_dict是:{frozenset({1, 2, 3}): 2.2}
        
        my_dict = {}
        my_dict[[1, 2, 3]] = 2.2
        这个非操作是非法的,TypeError: unhashable type: 'list' 即list不能哈希
        
    
    当然,办法总比困难多,我试过将list转为str,将字符串作为key放入dict。这样也是可以,但是需要两个函数专门处理,
    并且这两个解析函数还需要根据不同的数据类型专门写。
    """
    frozen_items = [frozenset(i) for i in enumerate(items)]
    # print(frozen_items)
    return frozen_items

3.4.2 然后计算频繁1-项集

# 创建候选1项集
c1 = build_c1(data_set)
# 从候选1项集 到 频繁1项集
l1 = ck_2_lk(data_set, ck=c1, min_support=0.05)

得到c1后得到l1。

整个ck_2_lk函数就是计算候选k-项集的频繁k-项集:

def ck_2_lk(data_set, ck, min_support):
    """
    根据候选k项集生成频繁k项集,依据min_support
    :param data_set: 数据集 list类型
    :param ck: 候选k项集 list类型,list装frozenset
    :param min_support: float 最小支持度
    :return: lk dict类型
    """

    # 频数字典 用来记录每个项集出现的频数
    support = {}
    # 用数据集的每一行跟候选项集的每个项对比,若该项集是其中子集,则+1,否则为0
    for row in data_set:
        for item in ck:
            if item.issubset(row):
                support[item] = support.get(item, 0) + 1
    # print(support)
    # 计算频率需要用到长度
    length = len(data_set)
    lk = {}
    for key, value in support.items():
        # print(key, value)
        percent = value / length
        # 频率大于最小支持度才能进入频繁项集
        if percent > min_support:
            lk[key] = percent

    return lk

3.4.3 整体频繁项集的逻辑

这个函数就是复刻上面流程图的内容,需要注意的是Lk先等于l1:

def get_all_L(data_set, min_support):
    """
    把所有的频繁项集拿到
    :param data_set: 数据
    :param min_support:  最小支持度
    :return:
    """
    # 创建候选1项集
    c1 = build_c1(data_set)
    # 从候选1项集 到 频繁1项集
    l1 = ck_2_lk(data_set, ck=c1, min_support=0.05)
    L = l1
    Lk = l1
    while len(Lk) > 0:
        lk_key_list = list(Lk.keys())
        # 频繁k 到 候选k+1
        ck_plus_1 = lk_2_ck_plus_1(lk_key_list)
        # 候选k 到 频繁k
        Lk = ck_2_lk(data_set, ck_plus_1, min_support)
        if len(Lk) > 0:
            L.update(Lk)
        else:
            break
    return L

3.5 找到关联规则

关联规则需要使用置信度公式,并且需要通过最小置信度的筛选:

# 得到所有频繁项集
L = get_all_L(data_set, 0.05)

将所有频繁项集丢入函数中:

# 得到所有关联规则
result = rules_from_L(L, min_confidence=0.05)

函数内容:

def rules_from_L(L, min_confidence):
    # 保存所有候选的关联规则
    rules = []
    for Lk in L:
        # 频繁项集长度要大于1才能生成关联规则
        if len(Lk) > 1:
            rules.extend(rules_from_item(Lk))
    result = []
    for left, right in rules:
        # left和right都是frozenset类型 二者可以取并集 然后L里去查询支持度
        support = L[left | right]
        # 置信度公式
        confidence = support / L[left]
        lift = confidence / L[right]
        if confidence > min_confidence:
            result.append({"左": left, "右": right, "支持度": support, "置信度": confidence, "提升度": lift})

    return result

其实在这里我们已经得到了关联规则,但是由于前面把数据转为数字我们无法分辨,所以需要转换回来。

3.6 数据逆转换

索引转字符串:

# 把index转成具体商品名称
result = return_2_str(index_2_str,result)

具体内容:

def return_2_str(index_2_str, result):
    """
    把index转为具体商品名称
    :param index_2_str:  index:str 的dict
    :param result: 关联规则的list
    :return:
    """
    for item in result:
        left = list(item['左'])[0]
        true_left = index_2_str[left]
        right = list(item['右'])[0]
        true_right = index_2_str[right]
        item['左'] = frozenset({true_left})
        item['右'] = frozenset({true_right})
    return result

到此整个Apriori算法的手写做完了,并且使用一个较小的貌似真实的数据集来营造一个类似真实的应用场景。整个项目代码我会传到github上。

4、 算法优劣

虽说这个算法是关联性分析算法的祖师级算法,它的思路创新无可厚非,但是仍然存在许多缺陷:

  • 优点: 易编码实现
  • 缺点: 在大数据集上可能较慢
  • 适用数据类型: 数值型 或者 标称型数据。

项目地址

点击进入github