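Cohort analysis groups the users from a given time period into a cohort and then compares several cohorts on some attribute over time. It is very useful in certain scenarios. For example, suppose a website or app shipped a new feature or design change in each of 4 consecutive weeks, and you want to know how each change affected users. You can treat each week's new sign-ups as a cohort and track the behavior of these 4 cohorts over the following period. The impact of each of the 4 changes then becomes clearly visible.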
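To make the weekly-cohort idea concrete, here is a minimal sketch. It assumes a hypothetical `events` table with one row per user action, plus each user's `signup_date`. The column names and the toy rows are made up for illustration. Each user's cohort is the Monday of their sign-up week, and the table counts how many distinct users from each cohort were active N weeks later.

```python
import pandas as pd

# Hypothetical toy data: one row per user action (column names are assumptions)
events = pd.DataFrame({
    "user_id":     [1, 1, 2, 2, 3, 3, 4],
    "signup_date": pd.to_datetime(["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02",
                                   "2024-01-08", "2024-01-08", "2024-01-09"]),
    "event_date":  pd.to_datetime(["2024-01-01", "2024-01-10", "2024-01-03", "2024-01-20",
                                   "2024-01-08", "2024-01-16", "2024-01-09"]),
})

# Cohort label = Monday of the sign-up week (dayofweek: Monday=0 ... Sunday=6)
events["cohort_week"] = events["signup_date"] - pd.to_timedelta(events["signup_date"].dt.dayofweek, unit="D")
# Whole weeks between the start of the cohort week and the activity (0 = sign-up week)
events["week_offset"] = (events["event_date"] - events["cohort_week"]).dt.days // 7

# Rows = cohort week, columns = weeks since sign-up, values = number of distinct active users
cohort_table = (events.groupby(["cohort_week", "week_offset"])["user_id"]
                      .nunique()
                      .unstack(fill_value=0))
print(cohort_table)
```

Reading one row of `cohort_table` left to right shows how a single weekly cohort's activity decays. Comparing rows shows whether a later change kept users around longer.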

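I recently needed to run a cohort analysis. Since all the data lives in a database, I wanted to connect to it from Python and analyze it directly. Google turned up an article that explains the method very clearly, so I followed it and reimplemented it in Python. The sample data can be downloaded here.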

1. Read the data

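Plain `import module` (with an alias such as `pd`) is the usual recommendation over `from ... import ...`. The main reason is that every name stays traceable to its module; performance is not really a factor. The sample data is typical purchase data, and customers are assigned to cohorts by the time of their first purchase (here, its month).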

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# sheet_name selects the worksheet to read
df = pd.read_excel('relay-foods.xlsx', sheet_name='Purchase Data')
df.head()

   OrderId   OrderDate  UserId  TotalCharges  CommonId  PupId  PickupDate
0      262  2009-01-11      47         50.67     TRQKD      2  2009-01-12
1      278  2009-01-20      47         26.60     4HH2S      3  2009-01-20
2      294  2009-02-03      47         38.71     3TRDC      2  2009-02-04
3      301  2009-02-06      47         53.38     NGAZJ      2  2009-02-09
4      302  2009-02-06      47         14.28     FFYHD      2  2009-02-09
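The author's real data sits in a database rather than an Excel file. Below is a minimal sketch of loading the same columns with `pandas.read_sql_query`. It assumes an SQLite database and a hypothetical `orders` table. The in-memory database and the table name are made up, and the three rows are copied from the output above. `parse_dates` matters because the later steps call date methods on `OrderDate`.

```python
import sqlite3

import pandas as pd

# Hypothetical in-memory database standing in for the real one; table/column names are assumptions
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (OrderId INTEGER, OrderDate TEXT, UserId INTEGER, TotalCharges REAL)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?, ?, ?)",
    [(262, "2009-01-11", 47, 50.67), (278, "2009-01-20", 47, 26.60), (294, "2009-02-03", 47, 38.71)],
)

# parse_dates converts the TEXT column to datetime64, so .dt / strftime work in later steps
df_db = pd.read_sql_query(
    "SELECT OrderId, OrderDate, UserId, TotalCharges FROM orders ORDER BY OrderId",
    conn,
    parse_dates=["OrderDate"],
)
conn.close()
print(df_db)
```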

2. Determine the month of each OrderDate and assign cohorts based on OrderDate

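Create two columns, OrderPeriod and CohortGroup. OrderPeriod is the month of the purchase date. CohortGroup assigns each UserId to a cohort according to the month of that user's first purchase.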

df = df.assign(OrderPeriod = df.OrderDate.map(lambda x: x.strftime('%Y-%m'))) \
       .set_index('UserId')

# groupby(level=0): `level` refers to a level of the index; for a MultiIndex, level=0 or level=1 selects which index level to group by.
# Here the index is UserId, so the earliest OrderDate per user is that user's cohort month; assign aligns it back onto every row of that user.
df = df.assign(CohortGroup = df.groupby(level=0).OrderDate.min().apply(lambda x: x.strftime('%Y-%m'))) \
       .reset_index()

df.head()

   UserId  OrderId   OrderDate  TotalCharges  CommonId  PupId  PickupDate  OrderPeriod  CohortGroup
0      47      262  2009-01-11         50.67     TRQKD      2  2009-01-12      2009-01      2009-01
1      47      278  2009-01-20         26.60     4HH2S      3  2009-01-20      2009-01      2009-01
2      47      294  2009-02-03         38.71     3TRDC      2  2009-02-04      2009-02      2009-01
3      47      301  2009-02-06         53.38     NGAZJ      2  2009-02-09      2009-02      2009-01
4      47      302  2009-02-06         14.28     FFYHD      2  2009-02-09      2009-02      2009-01
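A slightly more direct variant avoids setting and then resetting the index. This is a sketch, not the original article's code. `groupby(...).transform('min')` returns each user's first order date aligned to every one of that user's rows, and `dt.to_period('M')` keeps the month as a pandas `Period` instead of a string. The toy data is the five rows shown above plus a hypothetical user 99.

```python
import pandas as pd

# Toy orders: the five rows shown above (user 47) plus a hypothetical user 99
orders = pd.DataFrame({
    "UserId":    [47, 47, 47, 47, 47, 99],
    "OrderDate": pd.to_datetime(["2009-01-11", "2009-01-20", "2009-02-03",
                                 "2009-02-06", "2009-02-06", "2009-02-15"]),
})

# Month of each purchase, as a monthly Period (prints like 2009-01)
orders["OrderPeriod"] = orders["OrderDate"].dt.to_period("M")
# transform('min') broadcasts each user's first OrderDate back onto all of that user's rows
orders["CohortGroup"] = orders.groupby("UserId")["OrderDate"].transform("min").dt.to_period("M")
print(orders)
```

Keeping months as `Period` values, rather than strings, also makes it possible to do month arithmetic on them later.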

3. Count users, orders, and total charges for each CohortGroup in each OrderPeriod

# pd.Series.nunique --> Return number of unique elements in the object.
cohorts = df.groupby(['CohortGroup', 'OrderPeriod']) \
            .agg({'UserId': pd.Series.nunique,
                 'OrderId': pd.Series.nunique,
                 'TotalCharges': 'sum'})
cohorts.rename(columns={'UserId': 'TotalUsers', 'OrderId': 'TotalOrders'}, inplace=True)
cohorts.head()

                         TotalCharges  TotalOrders  TotalUsers
CohortGroup OrderPeriod
2009-01     2009-01          1850.255           30          22
            2009-02          1351.065           25           8
            2009-03          1357.360           26          10
            2009-04          1604.500           28           9
            2009-05          1575.625           26          10
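Since pandas 0.25, the same aggregation can also be written with named aggregation. This form names the output columns directly, so the separate `rename` call is no longer needed. Here is a sketch on a small made-up frame with the step-2 columns:

```python
import pandas as pd

# Toy frame in the shape produced by step 2 (values are made up)
df = pd.DataFrame({
    "CohortGroup":  ["2009-01", "2009-01", "2009-01", "2009-01"],
    "OrderPeriod":  ["2009-01", "2009-01", "2009-02", "2009-02"],
    "UserId":       [47, 48, 47, 47],
    "OrderId":      [262, 263, 294, 301],
    "TotalCharges": [50.67, 10.00, 38.71, 53.38],
})

# Named aggregation: output_name=(input column, aggregation function)
cohorts = df.groupby(["CohortGroup", "OrderPeriod"]).agg(
    TotalUsers=("UserId", "nunique"),
    TotalOrders=("OrderId", "nunique"),
    TotalCharges=("TotalCharges", "sum"),
)
print(cohorts)
```

The output columns appear in the order of the keyword arguments.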

4. Label the cohort period of each CohortGroup

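For example, the first period of the 2009-01 cohort is 2009-01, and its second through fifth periods are 2009-02, ..., 2009-05. Each CohortGroup's OrderPeriods therefore need to be mapped to that cohort's 1st, 2nd, ... period.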

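def cohort_period(df):
    """Number the OrderPeriods of one CohortGroup 1, 2, 3, ... (already sorted by the step-3 groupby)."""
    df = df.copy()  # work on a copy rather than the group slice passed in by apply
    df['CohortPeriod'] = np.arange(len(df)) + 1
    return df

# group_keys=False keeps the (CohortGroup, OrderPeriod) index; without it, pandas 2.x prepends CohortGroup a second time
cohorts = cohorts.groupby(level=0, group_keys=False).apply(cohort_period)
cohorts.head()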

                         TotalCharges  TotalOrders  TotalUsers  CohortPeriod
CohortGroup OrderPeriod
2009-01     2009-01          1850.255           30          22             1
            2009-02          1351.065           25           8             2
            2009-03          1357.360           26          10             3
            2009-04          1604.500           28           9             4
            2009-05          1575.625           26          10             5

Above, `groupby(level=0)` groups by the first index level, i.e. CohortGroup, and then applies cohort_period to each group. The groups produced by groupby look like this:

[(k, v) for k, v in cohorts.head(5).groupby(level=0)]
[('2009-01',
                           TotalCharges  TotalOrders  TotalUsers  CohortPeriod
  CohortGroup OrderPeriod                                                     
  2009-01     2009-01          1850.255           30          22             1
              2009-02          1351.065           25           8             2
              2009-03          1357.360           26          10             3
              2009-04          1604.500           28           9             4
              2009-05          1575.625           26          10             5)]
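Numbering rows with `np.arange` assumes every cohort has orders in every month after its first one. If a month is missing, all later periods shift down by one. The sketch below instead counts calendar months between OrderPeriod and CohortGroup. It assumes both are `'YYYY-MM'` strings as in step 2. The helper `months_between` and the toy values are made up for illustration.

```python
import pandas as pd

# Toy cohort table with a gap: the 2009-02 cohort has no orders in 2009-03 (values are made up)
cohorts = pd.DataFrame({
    "CohortGroup": ["2009-01", "2009-01", "2009-01", "2009-02", "2009-02"],
    "OrderPeriod": ["2009-01", "2009-02", "2009-03", "2009-02", "2009-04"],
    "TotalUsers":  [5, 3, 2, 4, 1],
}).set_index(["CohortGroup", "OrderPeriod"])

def months_between(start, end):
    """Hypothetical helper: whole calendar months from 'YYYY-MM' start to 'YYYY-MM' end."""
    sy, sm = map(int, start.split("-"))
    ey, em = map(int, end.split("-"))
    return (ey - sy) * 12 + (em - sm)

group = cohorts.index.get_level_values("CohortGroup")
period = cohorts.index.get_level_values("OrderPeriod")
# +1 so that a cohort's own month is period 1, matching the numbering used above
cohorts["CohortPeriod"] = [months_between(g, p) + 1 for g, p in zip(group, period)]
print(cohorts)
```

With row numbering, the 2009-04 row of the 2009-02 cohort would be labeled period 2. The month difference correctly labels it period 3.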

5. Verify that the cohort results are correct

x = df[(df.CohortGroup == '2009-01') & (df.OrderPeriod == '2009-01')]
y = cohorts.loc[('2009-01', '2009-01')]  # .ix was removed from pandas; .loc selects by label

assert x.UserId.nunique() == y.TotalUsers
assert x.OrderId.nunique() == y.TotalOrders
assert np.isclose(x.TotalCharges.sum(), y.TotalCharges)  # floating-point sums: compare with a tolerance

x = df[(df.CohortGroup == '2009-03') & (df.OrderPeriod == '2009-05')]
y = cohorts.loc[('2009-03', '2009-05')]

assert x.UserId.nunique() == y.TotalUsers
assert x.OrderId.nunique() == y.TotalOrders
assert np.isclose(x.TotalCharges.sum(), y.TotalCharges)
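The two spot checks above can be extended to every (CohortGroup, OrderPeriod) pair. The following is a sketch, assuming `df` and `cohorts` have the shapes from steps 2 and 3. The toy data and the helper name `check_cohorts` are made up.

```python
import numpy as np
import pandas as pd

# Toy data in the step-2 shape (values are made up)
df = pd.DataFrame({
    "CohortGroup":  ["2009-01", "2009-01", "2009-01", "2009-02"],
    "OrderPeriod":  ["2009-01", "2009-01", "2009-02", "2009-02"],
    "UserId":       [47, 48, 47, 50],
    "OrderId":      [262, 263, 294, 300],
    "TotalCharges": [50.67, 12.50, 38.71, 20.00],
})
cohorts = (df.groupby(["CohortGroup", "OrderPeriod"])
             .agg({"UserId": "nunique", "OrderId": "nunique", "TotalCharges": "sum"})
             .rename(columns={"UserId": "TotalUsers", "OrderId": "TotalOrders"}))

def check_cohorts(df, cohorts):
    """Hypothetical helper: recompute every cohort cell from raw rows and compare with the aggregate."""
    for (group, period), row in cohorts.iterrows():
        x = df[(df.CohortGroup == group) & (df.OrderPeriod == period)]
        assert x.UserId.nunique() == row.TotalUsers, (group, period)
        assert x.OrderId.nunique() == row.TotalOrders, (group, period)
        assert np.isclose(x.TotalCharges.sum(), row.TotalCharges), (group, period)
    return len(cohorts)  # number of cells checked

n_checked = check_cohorts(df, cohorts)
print(n_checked, "cohort cells verified")
```

Each failing assertion carries the offending (group, period) pair, which makes a mismatch easy to locate.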

User retention by CohortGroup

6. Compute the number of users in each CohortGroup's first CohortPeriod

cohorts = cohorts.reset_index() \
                 .set_index(['CohortGroup', 'CohortPeriod'])

cohort_group_size = cohorts.TotalUsers.groupby(level=0).first()
cohort_group_size
CohortGroup
2009-01     22
2009-02     15
2009-03     13
2009-04     39
2009-05     50
2009-06     32
2009-07     50
2009-08     31
2009-09     37
2009-10     54
2009-11    130
2009-12     65
2010-01     95
2010-02    100
2010-03     24
Name: TotalUsers, dtype: int64
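Each user's CohortGroup is, by definition, the month of their first order, so every user appears in their cohort's first period. The cohort size can therefore also be counted directly from the order rows, which gives a quick cross-check of the numbers above. Here is a sketch on made-up rows in the step-2 shape:

```python
import pandas as pd

# Toy order rows in the step-2 shape (values are made up)
df = pd.DataFrame({
    "UserId":      [47, 47, 48, 50, 50],
    "CohortGroup": ["2009-01", "2009-01", "2009-01", "2009-02", "2009-02"],
    "OrderPeriod": ["2009-01", "2009-02", "2009-01", "2009-02", "2009-03"],
})

# Distinct users per cohort, counted over all of their orders
size_direct = df.groupby("CohortGroup")["UserId"].nunique()
# Distinct users per cohort counted only in the cohort's own month (period 1)
first_month = df[df.OrderPeriod == df.CohortGroup]
size_first_period = first_month.groupby("CohortGroup")["UserId"].nunique()
print(size_direct)
```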

7. Compute the retention rate for each CohortPeriod

user_retention = cohorts.TotalUsers.unstack(0).divide(cohort_group_size, axis=1)
user_retention.head()

CohortGroup    2009-01   2009-02   2009-03   2009-04  2009-05  2009-06  2009-07   2009-08   2009-09   2009-10   2009-11   2009-12   2010-01  2010-02  2010-03
CohortPeriod
1             1.000000  1.000000  1.000000  1.000000     1.00  1.00000     1.00  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000     1.00      1.0
2             0.363636  0.200000  0.307692  0.333333     0.26  0.46875     0.46  0.354839  0.405405  0.314815  0.246154  0.261538  0.526316     0.19      NaN
3             0.454545  0.333333  0.384615  0.256410     0.24  0.28125     0.26  0.290323  0.378378  0.222222  0.200000  0.276923  0.273684      NaN      NaN
4             0.409091  0.066667  0.307692  0.333333     0.10  0.18750     0.20  0.225806  0.216216  0.240741  0.223077  0.107692       NaN      NaN      NaN
5             0.454545  0.266667  0.076923  0.153846     0.08  0.21875     0.22  0.193548  0.351351  0.240741  0.100000       NaN       NaN      NaN      NaN
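Written as a formula, the division in step 7 is as follows. The symbols are notation introduced here: $U_{c,k}$ is the TotalUsers of cohort $c$ in CohortPeriod $k$, and $r_{c,k}$ is its retention rate. The worked value uses the 2009-01 cohort, which has 8 users in period 2 and 22 in period 1 according to step 3.

$$
r_{c,k} = \frac{U_{c,k}}{U_{c,1}}, \qquad
r_{\text{2009-01},\,2} = \frac{8}{22} \approx 0.3636
$$

This matches the 0.363636 in the table. A NaN means the data has no orders for that cohort in that period. For example, the 2010-03 cohort only has its first month in the data.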

Retention curves

user_retention[['2009-01', '2009-05', '2009-08']] \
  .plot(figsize=(11, 6), color=['#4285f4', '#EA4335', '#A60628'])
plt.title("Cohorts: User Retention")
plt.xticks(np.arange(1, len(user_retention)+1, 1))
plt.xlim(1, len(user_retention))
plt.ylabel('% of Cohort Purchasing', fontsize=16)


Retention heatmap

import seaborn as sns
sns.set(style='white')

plt.figure(figsize=(16, 8))
plt.title('Cohorts: User Retention', fontsize=14)
sns.heatmap(user_retention.T, 
            mask=user_retention.T.isnull(), 
            annot=True, fmt='.0%')
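For reuse, steps 2 to 7 can be folded into a single function. This is a sketch rather than the article's code, and `retention_matrix` is a hypothetical name. It keeps months as pandas `Period` values instead of strings. It also counts CohortPeriod by calendar-month difference, so a month without orders leaves a NaN instead of shifting later periods. The input is assumed to have `UserId` and a datetime64 `OrderDate` column, and the toy orders are made up.

```python
import pandas as pd

def retention_matrix(orders):
    """Rows = CohortPeriod (1 = cohort's first month), columns = CohortGroup, values = share of cohort active.

    Assumes columns UserId and OrderDate (datetime64); (cohort, period) pairs without orders are NaN.
    """
    o = orders.copy()
    o["OrderPeriod"] = o["OrderDate"].dt.to_period("M")
    o["CohortGroup"] = o.groupby("UserId")["OrderDate"].transform("min").dt.to_period("M")
    # Calendar months between cohort month and order month, +1 so the first month is period 1
    o["CohortPeriod"] = ((o["OrderPeriod"].dt.year - o["CohortGroup"].dt.year) * 12
                         + (o["OrderPeriod"].dt.month - o["CohortGroup"].dt.month) + 1)
    users = (o.groupby(["CohortPeriod", "CohortGroup"])["UserId"].nunique()
              .unstack("CohortGroup"))
    return users / users.loc[1]  # divide each cohort column by its period-1 user count

# Toy orders (made up): users 1 and 2 start in 2009-01, user 3 in 2009-02
orders = pd.DataFrame({
    "UserId": [1, 1, 2, 2, 3],
    "OrderDate": pd.to_datetime(["2009-01-05", "2009-03-10", "2009-01-20", "2009-02-02", "2009-02-14"]),
})
ret = retention_matrix(orders)
print(ret)
```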
