Table of Contents

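  • Project 1: PM2.5 Prediction
  • Friendly Reminder
  • Project Description
  • Dataset
  • Project Requirements
  • Data Preparation
  • Environment Setup / Installation
  • Preprocessing
  • Feature Extraction (1)
  • Feature Extraction (2)
  • Normalization
  • Split the Training Data into a "Training Set" and a "Validation Set"
  • Training
  • Testing
  • Prediction
  • Save the Predictions to a CSV File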


Project 1: PM2.5 Prediction

Friendly Reminder

Students are encouraged to try the assignment on their own in the course assignment area first!

Project Description

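  • The data for this assignment are observations downloaded from the Air Quality Monitoring Network of the Environmental Protection Administration, Executive Yuan (Taiwan).
  • The goal is to implement linear regression that predicts PM2.5 values.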

Dataset

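  • This assignment uses observation records from the Fengyuan station, split into a train set and a test set. The train set contains all data from the first 20 days of every month; the test set is sampled from the station's remaining data.
  • train.csv: the complete data for the first 20 days of each month.
  • test.csv: from the remaining data, every 10 consecutive hours form one record; all observations of the first nine hours are the features and the PM2.5 value of the tenth hour is the answer. There are 240 distinct test records in total; predict the PM2.5 of these 240 records from their features.
  • The data contain 18 observed items: AMB_TEMP, CH4, CO, NMHC, NO, NO2, NOx, O3, PM10, PM2.5, RAINFALL, RH, SO2, THC, WD_HR, WIND_DIREC, WIND_SPEED, WS_HR.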
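Each day in train.csv is a block of 18 rows, one per item, so an item's position in the list above is its row offset inside that block. This is why the feature-extraction code later reads PM2.5 from row 9. A minimal sketch, assuming the listed order matches the row order in the file (as the printed head of the data in the preprocessing step suggests); ITEMS and ITEM_INDEX are illustrative names:

```python
# The 18 observed items, in the order listed above (one row per item for each day)
ITEMS = [
    "AMB_TEMP", "CH4", "CO", "NMHC", "NO", "NO2", "NOx", "O3", "PM10",
    "PM2.5", "RAINFALL", "RH", "SO2", "THC", "WD_HR", "WIND_DIREC",
    "WIND_SPEED", "WS_HR",
]

# Row offset of each item inside an 18-row daily block
ITEM_INDEX = {name: i for i, name in enumerate(ITEMS)}

print(len(ITEMS))              # 18
print(ITEM_INDEX["PM2.5"])     # 9
print(ITEM_INDEX["RAINFALL"])  # 10
```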

Project Requirements

  • Implement linear regression by hand; the only allowed optimization method is gradient descent.
  • Using numpy.linalg.lstsq is not allowed.

Data Preparation

Environment Setup / Installation

!pip install --upgrade pandas
import sys
import pandas as pd
import numpy as np

data = pd.read_csv('work/hw1_data/train.csv', encoding = 'big5')
print(pd.__version__)
1.3.5

Preprocessing

Keep only the numeric part of the table, and replace every 'NR' entry in the RAINFALL rows with 0.

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

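# The raw training data is 4320 x 27: 3 metadata columns (date, station, item) plus 24 hourly columns named 0-23.
# Every 18 rows hold one day (one row per observed item), so there are 4320 / 18 = 240 days of data.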
print(data.head(20))
print(data.shape)

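# Drop the first three non-numeric columns (date, station, item)
data = data.iloc[:, 3:]
# Replace every 'NR' cell (these appear in the RAINFALL rows) with 0
data[data == 'NR'] = 0
# Convert to a float array of shape (4320, 24)
raw_data = data.to_numpy().astype(float)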
日期  測站          測項     0     1     2     3     4     5     6  ...  \
0   2014/1/1  豐原    AMB_TEMP    14    14    14    13    12    12    12  ...   
1   2014/1/1  豐原         CH4   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   
2   2014/1/1  豐原          CO  0.51  0.41  0.39  0.37  0.35   0.3  0.37  ...   
3   2014/1/1  豐原        NMHC   0.2  0.15  0.13  0.12  0.11  0.06   0.1  ...   
4   2014/1/1  豐原          NO   0.9   0.6   0.5   1.7   1.8   1.5   1.9  ...   
5   2014/1/1  豐原         NO2    16   9.2   8.2   6.9   6.8   3.8   6.9  ...   
6   2014/1/1  豐原         NOx    17   9.8   8.7   8.6   8.5   5.3   8.8  ...   
7   2014/1/1  豐原          O3    16    30    27    23    24    28    24  ...   
8   2014/1/1  豐原        PM10    56    50    48    35    25    12     4  ...   
9   2014/1/1  豐原       PM2.5    26    39    36    35    31    28    25  ...   
10  2014/1/1  豐原    RAINFALL    NR    NR    NR    NR    NR    NR    NR  ...   
11  2014/1/1  豐原          RH    77    68    67    74    72    73    74  ...   
12  2014/1/1  豐原         SO2   1.8     2   1.7   1.6   1.9   1.4   1.5  ...   
13  2014/1/1  豐原         THC     2     2     2   1.9   1.9   1.8   1.9  ...   
14  2014/1/1  豐原       WD_HR    37    80    57    76   110   106   101  ...   
15  2014/1/1  豐原  WIND_DIREC    35    79   2.4    55    94   116   106  ...   
16  2014/1/1  豐原  WIND_SPEED   1.4   1.8     1   0.6   1.7   2.5   2.5  ...   
17  2014/1/1  豐原       WS_HR   0.5   0.9   0.6   0.3   0.6   1.9     2  ...   
18  2014/1/2  豐原    AMB_TEMP    16    15    15    14    14    15    16  ...   
19  2014/1/2  豐原         CH4   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   

      14    15    16    17    18    19    20    21    22    23  
0     22    22    21    19    17    16    15    15    15    15  
1    1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  
2   0.37  0.37  0.47  0.69  0.56  0.45  0.38  0.35  0.36  0.32  
3    0.1  0.13  0.14  0.23  0.18  0.12   0.1  0.09   0.1  0.08  
4    2.5   2.2   2.5   2.3   2.1   1.9   1.5   1.6   1.8   1.5  
5     11    11    22    28    19    12   8.1     7   6.9     6  
6     14    13    25    30    21    13   9.7   8.6   8.7   7.5  
7     65    64    51    34    33    34    37    38    38    36  
8     52    51    66    85    85    63    46    36    42    42  
9     36    45    42    49    45    44    41    30    24    13  
10    NR    NR    NR    NR    NR    NR    NR    NR    NR    NR  
11    47    49    56    67    72    69    70    70    70    69  
12   3.9   4.4   9.9   5.1   3.4   2.3     2   1.9   1.9   1.9  
13   1.9   1.9   1.9   2.1     2   1.9   1.9   1.9   1.9   1.9  
14   307   304   307   124   118   121   113   112   106   110  
15   313   305   291   124   119   118   114   108   102   111  
16   2.5   2.2   1.4   2.2   2.8     3   2.6   2.7   2.1   2.1  
17   2.1   2.1   1.9     1   2.5   2.5   2.8   2.6   2.4   2.3  
18    24    24    23    21    20    19    18    18    18    18  
19   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  

[20 rows x 27 columns]
(4320, 27)
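The same 'NR' cleanup can be checked on a toy array with NumPy alone. In the sketch below, the string array is a made-up stand-in for two rows of train.csv after the three metadata columns are dropped; it replaces 'NR' with 0 and converts everything to float, which is what the pandas code above does to produce raw_data:

```python
import numpy as np

# Hypothetical stand-in for two rows of the CSV after dropping the 3 metadata columns:
# an ordinary measurement row and a RAINFALL row that contains 'NR' entries
toy = np.array([
    ["14", "14", "13"],
    ["NR", "0.4", "NR"],
], dtype=object)

# Replace every 'NR' with 0, then convert the whole table to floats
cleaned = np.where(toy == "NR", 0, toy).astype(float)

print(cleaned.tolist())  # [[14.0, 14.0, 13.0], [0.0, 0.4, 0.0]]
```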

Feature Extraction (1)

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

Regroup the raw 4320 * 24 data by month into 12 arrays, each of 18 (features) * 480 (hours).

# Regroup the raw data (240 days * 18 features * 24 hours) by month into 12 months (20 days each) of 18 features * 480 hours.
month_data = {}
for month in range(12):
    sample = np.empty([18, 480])
    for day in range(20):
        sample[:, day * 24 : (day + 1) * 24] = raw_data[18 * (20 * month + day) : 18 * (20 * month + day + 1), :]
    month_data[month] = sample
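The double loop above is equivalent to a reshape followed by a transpose. The sketch below checks this on random data with the same shape as raw_data; regroup_loop and regroup_vectorized are hypothetical helper names, and the vectorized form is only an alternative, not what the original code uses:

```python
import numpy as np

def regroup_loop(raw_data):
    """Loop version, as in the code above: 12 months of 18 x 480."""
    month_data = {}
    for month in range(12):
        sample = np.empty([18, 480])
        for day in range(20):
            sample[:, day * 24:(day + 1) * 24] = raw_data[18 * (20 * month + day):18 * (20 * month + day + 1), :]
        month_data[month] = sample
    return month_data

def regroup_vectorized(raw_data):
    """Same result via reshape: (12 months, 20 days, 18 features, 24 hours) -> (12, 18, 480)."""
    blocks = raw_data.reshape(12, 20, 18, 24)
    # Move the feature axis in front of the day axis so that day-major hours become contiguous
    return blocks.transpose(0, 2, 1, 3).reshape(12, 18, 480)

rng = np.random.default_rng(0)
raw = rng.random((4320, 24))  # stand-in for the real raw_data

loop_result = regroup_loop(raw)
vec_result = regroup_vectorized(raw)
print(all(np.array_equal(loop_result[m], vec_result[m]) for m in range(12)))  # True
```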

Feature Extraction (2)

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

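Each month has 480 hours. Every 10 consecutive hours form one sample: the first 9 hours provide the features and the 10th hour provides the target, so each month yields 480 - 9 = 471 samples, for a total of 471 * 12. Each sample has 9 * 18 features (18 features per hour * 9 hours).

There are correspondingly 471 * 12 targets (the PM2.5 value in the 10th hour).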

# x holds the 12 * 471 training samples; each sample has 18 (features) * 9 (hours) values
x = np.empty([12 * 471, 18 * 9], dtype = float)

# y holds, for each training sample, the PM2.5 value of the 10th hour
y = np.empty([12 * 471, 1], dtype = float)

for month in range(12):
    for day in range(20):
        for hour in range(24):
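            # Skip windows that would run past the end of the month: the last valid window starts at
            # day index 19 (the 20th day), hour index 14, and its target is the month's final hour (index 479).
            # reshape(1, -1) flattens each 18 (features) x 9 (hours) slice into a single row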
            if day == 19 and hour > 14:
                continue
            x[month * 471 + day * 24 + hour, :] = month_data[month][:,day * 24 + hour : day * 24 + hour + 9].reshape(1, -1) # dim 18*9: 9 hours of feature 0, then 9 hours of feature 1, ...
            y[month * 471 + day * 24 + hour, 0] = month_data[month][9, day * 24 + hour + 9] #value
print(x)
print(y)
[[14.  14.  14.  ...  2.   2.   0.5]
 [14.  14.  13.  ...  2.   0.5  0.3]
 [14.  13.  12.  ...  0.5  0.3  0.8]
 ...
 [17.  18.  19.  ...  1.1  1.4  1.3]
 [18.  19.  18.  ...  1.4  1.3  1.6]
 [19.  18.  17.  ...  1.3  1.6  1.8]]
[[30.]
 [41.]
 [44.]
 ...
 [17.]
 [24.]
 [29.]]
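The window arithmetic can be checked on a single month. A minimal sketch, assuming NumPy 1.20 or later for numpy.lib.stride_tricks.sliding_window_view; the random month array stands in for month_data[month], and PM25_ROW and WINDOW are illustrative names. It builds all windows at once, confirms there are 480 - 9 = 471 of them, and checks that the flattening order matches reshape(1, -1) in the loop above (where month * 471 + day * 24 + hour is just the window's start hour within the month):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

PM25_ROW = 9   # row of PM2.5 among the 18 features
WINDOW = 9     # hours used as features

rng = np.random.default_rng(0)
month = rng.random((18, 480))  # stand-in for month_data[month]

# All length-9 windows along the hour axis: shape (18, 472, 9)
windows = sliding_window_view(month, WINDOW, axis=1)
# Keep the first 471 windows (the last one has no 10th hour to predict),
# then put the window index first and flatten each 18 x 9 window feature by feature
x_month = windows[:, :480 - WINDOW, :].transpose(1, 0, 2).reshape(480 - WINDOW, 18 * WINDOW)
# Target: PM2.5 in the hour right after each window
y_month = month[PM25_ROW, WINDOW:].reshape(-1, 1)

print(x_month.shape, y_month.shape)  # (471, 162) (471, 1)

# Same as the loop version for an arbitrary start hour, e.g. 100
start = 100
assert np.array_equal(x_month[start], month[:, start:start + WINDOW].reshape(-1))
assert y_month[start, 0] == month[PM25_ROW, start + WINDOW]
```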

Normalization

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

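# axis=0 reduces down each column (one result per column); axis=1 reduces across each row (one result per row)
# Here: the mean of each of the 18 * 9 columns over all 12 * 471 samples, i.e. a vector of 18 * 9 means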
mean_x = np.mean(x, axis = 0) #18 * 9 
std_x = np.std(x, axis = 0) #18 * 9 

# Z-score normalization: subtract each column's mean and divide by its standard deviation
for i in range(len(x)): #12 * 471
    for j in range(len(x[0])): #18 * 9 
        if std_x[j] != 0:
            x[i][j] = (x[i][j] - mean_x[j]) / std_x[j]
x
array([[-1.35825331, -1.35883937, -1.359222  , ...,  0.26650729,
         0.2656797 , -1.14082131],
       [-1.35825331, -1.35883937, -1.51819928, ...,  0.26650729,
        -1.13963133, -1.32832904],
       [-1.35825331, -1.51789368, -1.67717656, ..., -1.13923451,
        -1.32700613, -0.85955971],
       ...,
       [-0.88092053, -0.72262212, -0.56433559, ..., -0.57693779,
        -0.29644471, -0.39079039],
       [-0.7218096 , -0.56356781, -0.72331287, ..., -0.29578943,
        -0.39013211, -0.1095288 ],
       [-0.56269867, -0.72262212, -0.88229015, ..., -0.38950555,
        -0.10906991,  0.07797893]])
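The element-wise double loop can also be written with NumPy broadcasting. A sketch on a small made-up matrix, where zscore is a hypothetical helper name; as in the loop, columns whose standard deviation is 0 are left unchanged:

```python
import numpy as np

def zscore(x, mean, std):
    """Column-wise (x - mean) / std; columns with std == 0 are returned unchanged."""
    safe_std = np.where(std == 0, 1.0, std)  # avoid dividing by zero
    return np.where(std == 0, x, (x - mean) / safe_std)

# Made-up data: 4 samples, 3 features; the last feature is constant
x = np.array([[1.0, 10.0, 5.0],
              [2.0, 20.0, 5.0],
              [3.0, 30.0, 5.0],
              [4.0, 40.0, 5.0]])
mean_x = np.mean(x, axis=0)
std_x = np.std(x, axis=0)

x_norm = zscore(x, mean_x, std_x)
print(x_norm.mean(axis=0))  # approximately [0, 0, 5]
print(x_norm.std(axis=0))   # approximately [1, 1, 0]
```

Calling the same helper with the training mean_x and std_x would also normalize test_x, which is what the testing step does with its own loop.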

Split the Training Data into a "Training Set" and a "Validation Set"

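This part is a simple demonstration: it produces a training set used for training and a validation set that is held out from training and used only for validation, so that different settings can be compared. The split takes the first 80% of the rows in order, so the validation rows come from the last months of the year. Note that the training step below still fits on all of x.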

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

# len(x) is the number of rows: 12 * 471 = 5652
print(len(x))
5652
import math
# math.floor() rounds down
# Split the training data into a "training set" and a "validation set" at a ratio of 8:2
x_train_set = x[: math.floor(len(x) * 0.8), :]
y_train_set = y[: math.floor(len(y) * 0.8), :]
x_validation = x[math.floor(len(x) * 0.8): , :]
y_validation = y[math.floor(len(y) * 0.8): , :]

print(x_train_set)
print(y_train_set)
print(x_validation)
print(y_validation)

# The training set has 4521 rows and the validation set has 1131 rows
print(len(x_train_set))
print(len(y_train_set))
print(len(x_validation))
print(len(y_validation))
[[-1.35825331 -1.35883937 -1.359222   ...  0.26650729  0.2656797
  -1.14082131]
 [-1.35825331 -1.35883937 -1.51819928 ...  0.26650729 -1.13963133
  -1.32832904]
 [-1.35825331 -1.51789368 -1.67717656 ... -1.13923451 -1.32700613
  -0.85955971]
 ...
 [ 0.86929969  0.70886668  0.38952809 ...  1.39110073  0.2656797
  -0.39079039]
 [ 0.71018876  0.39075806  0.07157353 ...  0.26650729 -0.39013211
  -0.39079039]
 [ 0.3919669   0.07264944  0.07157353 ... -0.38950555 -0.39013211
  -0.85955971]]
  
  
[[30.]
 [41.]
 [44.]
 ...
 [ 7.]
 [ 5.]
 [14.]]
 
 
[[ 0.07374504  0.07264944  0.07157353 ... -0.38950555 -0.85856912
  -0.57829812]
 [ 0.07374504  0.07264944  0.23055081 ... -0.85808615 -0.57750692
   0.54674825]
 [ 0.07374504  0.23170375  0.23055081 ... -0.57693779  0.54674191
  -0.1095288 ]
 ...
 [-0.88092053 -0.72262212 -0.56433559 ... -0.57693779 -0.29644471
  -0.39079039]
 [-0.7218096  -0.56356781 -0.72331287 ... -0.29578943 -0.39013211
  -0.1095288 ]
 [-0.56269867 -0.72262212 -0.88229015 ... -0.38950555 -0.10906991
   0.07797893]]
   
   
[[13.]
 [24.]
 [22.]
 ...
 [17.]
 [24.]
 [29.]]
 
 
4521
4521
1131
1131
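The split above is sequential. If a random split is preferred, one option is to shuffle the row indices once and apply the same permutation to x and y so that features and targets stay aligned. A sketch with made-up arrays of the same shapes; random_split, the seed, and the 0.8 ratio are illustrative choices, not part of the original assignment:

```python
import math
import numpy as np

def random_split(x, y, train_ratio=0.8, seed=0):
    """Shuffle rows with one shared permutation, then split into train / validation."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    cut = math.floor(len(x) * train_ratio)
    train_idx, val_idx = order[:cut], order[cut:]
    return x[train_idx], y[train_idx], x[val_idx], y[val_idx]

# Made-up data with the same shapes as x and y above
x = np.arange(5652 * 162, dtype=float).reshape(5652, 162)
y = x[:, :1] * 2.0  # each target is derived from its own row, to check alignment

x_tr, y_tr, x_val, y_val = random_split(x, y)
print(len(x_tr), len(x_val))                    # 4521 1131
print(np.array_equal(y_tr, x_tr[:, :1] * 2.0))  # True: rows stayed aligned
```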

Training

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

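Because of the constant (bias) term, the dimension (dim) needs one extra column. eps is a tiny value added so that Adagrad's denominator is never 0.

Each dimension has its own gradient and weight (w), which are learned over repeated iterations (iter_time). Adagrad divides each dimension's gradient by the square root of that dimension's accumulated squared gradients, so every weight effectively gets its own learning rate.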
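In formulas, the training loop below computes the following, where X is the N x dim matrix whose first column is all ones, N = 12 * 471, η is learning_rate, ε is eps, and ⊙, the square root, and the division act element-wise:

```latex
\begin{aligned}
L(w) &= \lVert Xw - y \rVert^2, \qquad
\mathrm{RMSE}(w) = \sqrt{L(w) / N} \\
g_t &= \nabla L(w_t) = 2 X^\top (X w_t - y) \\
G_t &= \sum_{\tau=0}^{t} g_\tau \odot g_\tau \\
w_{t+1} &= w_t - \eta \, \frac{g_t}{\sqrt{G_t + \epsilon}}, \qquad w_0 = 0
\end{aligned}
```

Because L(w) and the RMSE have the same minimizer, descending on L also drives down the RMSE that the loop prints.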

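# Number of parameters: 18 * 9 features plus 1 bias term
dim = 18 * 9 + 1
w = np.zeros([dim, 1])
# Prepend a column of ones so that w[0] acts as the bias term
x = np.concatenate((np.ones([12 * 471, 1]), x), axis = 1).astype(float)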
learning_rate = 100
iter_time = 1000
adagrad = np.zeros([dim, 1])
eps = 0.0000000001

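# The printed loss is the RMSE (root mean square error): the square root of the mean of the squared
# differences between predictions and true values. It measures how far predictions deviate from the
# true values and is sensitive to outliers in the data.
# The gradient below is that of the sum of squared errors, which has the same minimizer as the RMSE.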
for t in range(iter_time):
    loss = np.sqrt(np.sum(np.power(np.dot(x, w) - y, 2))/471/12)
    # Print the current loss every 100 iterations
    if(t%100==0):
        print(str(t) + ":" + str(loss))
    gradient = 2 * np.dot(x.transpose(), np.dot(x, w) - y) #dim*1
    adagrad += gradient ** 2
    w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
# Save the learned weights w
np.save('work/weight.npy', w)
# Display w (the array shown after the loss log)
w
0:27.071214829194115
100:33.78905859777454
200:19.9137512981971
300:13.531068193689693
400:10.645466158446174
500:9.277353455475065
600:8.518042045956502
700:8.014061987588423
800:7.636756824775692
900:7.336563740371124





array([[ 2.13740269e+01],
       [ 3.58888909e+00],
       [ 4.56386323e+00],
       [ 2.16307023e+00],
       .... 
       [-4.23463160e-01],
       [ 4.89922051e-01]])
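The loop can be packaged as a reusable function, which also makes it easy to train on x_train_set and score on x_validation. The sketch below (train_adagrad and rmse are hypothetical helper names) runs the same Adagrad update on synthetic data generated from known weights; it is a check of the update rule, not the original author's code, and learning_rate=1.0 with 1000 iterations are simply values that work for this toy problem:

```python
import numpy as np

def rmse(x, y, w):
    """Root mean square error of the predictions x @ w against y."""
    return float(np.sqrt(np.mean((x @ w - y) ** 2)))

def train_adagrad(x, y, learning_rate, iter_time, eps=1e-10):
    """Gradient descent with Adagrad on the sum of squared errors, as in the loop above.

    x must already contain the leading column of ones for the bias term.
    """
    w = np.zeros((x.shape[1], 1))
    adagrad = np.zeros((x.shape[1], 1))
    for _ in range(iter_time):
        gradient = 2 * x.T @ (x @ w - y)  # gradient of sum((x @ w - y) ** 2)
        adagrad += gradient ** 2          # per-dimension sum of squared gradients
        w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
    return w

# Synthetic check: data generated from known weights (bias first) plus small noise
rng = np.random.default_rng(0)
true_w = np.array([[2.0], [1.0], [-3.0], [0.5]])
features = rng.standard_normal((1000, 3))
x = np.concatenate((np.ones((1000, 1)), features), axis=1)
y = x @ true_w + 0.1 * rng.standard_normal((1000, 1))

# Sequential 8:2 split, as in the section above
cut = 800
w = train_adagrad(x[:cut], y[:cut], learning_rate=1.0, iter_time=1000)
print(np.round(w.ravel(), 2))              # close to [2, 1, -3, 0.5]
print(round(rmse(x[cut:], y[cut:], w), 3))  # close to the noise level 0.1
```

With the real arrays, the analogous call would train on x_train_set and y_train_set (after prepending the ones column) with learning_rate=100 and iter_time=1000, then report rmse on x_validation and y_validation.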

Testing

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

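Load the test data and process it the same way as the training data (preprocessing and feature extraction), so that the test data become 240 samples of dimension 18 * 9 + 1.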

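# test.csv has no header row; the first two columns (sample id and item name) are dropped, leaving 9 hourly columns
testdata = pd.read_csv('work/hw1_data/test.csv', header = None, encoding = 'big5')
test_data = testdata.iloc[:, 2:].copy()
test_data[test_data == 'NR'] = 0
test_data = test_data.to_numpy().astype(float)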

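# Each test sample spans 18 consecutive rows; flatten each 18 x 9 block into one row of 18 * 9 values
test_x = np.empty([240, 18*9], dtype = float)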
for i in range(240):
    test_x[i, :] = test_data[18 * i: 18* (i + 1), :].reshape(1, -1)

# Z-score normalization, reusing the mean and std computed on the training data
for i in range(len(test_x)):
    for j in range(len(test_x[0])):
        if std_x[j] != 0:
            test_x[i][j] = (test_x[i][j] - mean_x[j]) / std_x[j]
            
test_x = np.concatenate((np.ones([240, 1]), test_x), axis = 1).astype(float)
test_x
array([[ 1.        , -0.24447681, -0.24545919, ..., -0.67065391,
        -1.04594393,  0.07797893],
       [ 1.        , -1.35825331, -1.51789368, ...,  0.17279117,
        -0.10906991, -0.48454426],
       [ 1.        ,  1.5057434 ,  1.34508393, ..., -1.32666675,
        -1.04594393, -0.57829812],
       ...,
       [ 1.        ,  0.3919669 ,  0.54981237, ...,  0.26650729,
        -0.20275731,  1.20302531],
       [ 1.        , -1.8355861 , -1.8360023 , ..., -1.04551839,
        -1.13963133, -1.14082131],
       [ 1.        , -1.35825331, -1.35883937, ...,  2.98427476,
         3.26367657,  1.76554849]])
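Because each test sample occupies 18 consecutive rows of 9 hourly values, the per-sample loop is the same as a single reshape of the whole array, and both flatten feature by feature, matching the order used for the training x. A sketch on random data shaped like test_data (240 * 18 rows, 9 columns):

```python
import numpy as np

rng = np.random.default_rng(0)
test_data = rng.random((240 * 18, 9))  # stand-in for the numeric part of test.csv

# Loop version, as in the code above
test_x_loop = np.empty([240, 18 * 9], dtype=float)
for i in range(240):
    test_x_loop[i, :] = test_data[18 * i:18 * (i + 1), :].reshape(1, -1)

# One reshape: rows are stored one after another, so each 18 x 9 block is contiguous
test_x_vec = test_data.reshape(240, 18 * 9)

print(np.array_equal(test_x_loop, test_x_vec))  # True
```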

Prediction

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

With the weights and the test data, we can predict the target.

# Predict by multiplying the test matrix by the weights w
w = np.load('work/weight.npy')
ans_y = np.dot(test_x, w)
ans_y
array([[ 5.17496040e+00],
       [ 1.83062143e+01],
       .....
       [ 1.43137440e+01],
       [ 1.57707266e+01]])

Save the Predictions to a CSV File

The figure below shows the full pipeline:

[Figure: full pipeline flowchart]

import csv
with open('work/submit.csv', mode='w', newline='') as submit_file:
    csv_writer = csv.writer(submit_file)
    header = ['id', 'value']
    print(header)
    csv_writer.writerow(header)
    for i in range(240):
        row = ['id_' + str(i), ans_y[i][0]]
        csv_writer.writerow(row)
        print(row)
['id', 'value']
['id_0', 5.174960398984726]
['id_1', 18.306214253527884]
['id_2', 20.491218094180542]
['id_3', 11.523942869805353]
['id_4', 26.61605675230616]
['id_5', 20.53134808176122]
['id_6', 21.906551018797376]
['id_7', 31.73646874706883]
['id_8', 13.391674055111736]
['id_9', 64.45646650291955]
['id_10', 20.264568836159466]
['id_11', 15.358576077361217]
['id_12', 68.58947276926725]
['id_13', 48.428113747457196]
['id_14', 18.702333824193207]
['id_15', 10.188595737466706]
['id_16', 30.74036285982042]
['id_17', 71.1322177635511]
['id_18', -4.130517391262453]
['id_19', 18.23569401642868]
['id_20', 38.57892227500775]
['id_21', 71.31151972531332]
['id_22', 7.410348162634064]
['id_23', 18.71795533032141]
['id_24', 14.937250260084582]
['id_25', 36.719736694705325]
['id_26', 17.961697005662696]
['id_27', 75.78946287210537]
['id_28', 12.309310248614484]
['id_29', 56.2953517396496]
['id_30', 25.113160865661484]
['id_31', 4.6102486740940325]
['id_32', 2.483770554515017]
['id_33', 24.759422261321284]
['id_34', 30.48028046559118]
['id_35', 38.463930746426634]
['id_36', 44.20231060933005]
['id_37', 30.086836019866013]
['id_38', 40.4736750157401]
['id_39', 29.22647990231738]
['id_40', 5.606456054343926]
['id_41', 38.666016078789596]
['id_42', 34.610213431877206]
['id_43', 48.38969750738481]
['id_44', 14.75724766694418]
['id_45', 34.46682011087208]
['id_46', 27.48310687418435]
['id_47', 12.000879378154064]
['id_48', 21.378036151603776]
['id_49', 28.54440309166329]
['id_50', 20.16551381841159]
['id_51', 10.796678149746501]
['id_52', 22.171035755750147]
['id_53', 53.44626310935228]
['id_54', 12.219581121610014]
['id_55', 43.30096845517152]
['id_56', 32.182335103285425]
['id_57', 22.567217514570817]
['id_58', 56.739514165547035]
['id_59', 20.745052945295463]
['id_60', 15.02885455747326]
['id_61', 39.8553015903851]
['id_62', 12.975340680728287]
['id_63', 51.74165959283004]
['id_64', 18.783369632539888]
['id_65', 12.348752842777701]
['id_66', 15.633623653541909]
['id_67', -0.058871470685014415]
['id_68', 41.50801107307594]
['id_69', 31.548747530656033]
['id_70', 18.60425115754707]
['id_71', 37.4768197248807]
['id_72', 56.52039065762305]
['id_73', 6.587877193521951]
['id_74', 12.229339737435023]
['id_75', 5.203696404134652]
['id_76', 47.92737510380062]
['id_77', 13.020705685594672]
['id_78', 17.110301693903615]
['id_79', 20.60323453100204]
['id_80', 21.2844815607846]
['id_81', 38.6929352905118]
['id_82', 30.02071667572584]
['id_83', 88.76740666723548]
['id_84', 35.98470023966825]
['id_85', 26.756913553477172]
['id_86', 23.963516843564452]
['id_87', 32.747242828083074]
['id_88', 22.189043755319933]
['id_89', 20.99215885362656]
['id_90', 29.555994316645464]
['id_91', 40.99216886651781]
['id_92', 8.625117809911576]
['id_93', 32.3214718088779]
['id_94', 46.59804436536758]
['id_95', 22.884070826723534]
['id_96', 31.518129728251637]
['id_97', 11.198233479766111]
['id_98', 28.527436642529615]
['id_99', 0.29115068008963196]
['id_100', 17.966961079539693]
['id_101', 27.124163929470143]
['id_102', 11.398232780652853]
['id_103', 16.426426865673516]
['id_104', 23.425261046922163]
['id_105', 40.616082670568396]
['id_106', 25.8641250265604]
['id_107', 5.422736951672383]
['id_108', 10.794921122256119]
['id_109', 72.86213692992126]
['id_110', 48.022837059481404]
['id_111', 15.74680827690301]
['id_112', 24.670410614177968]
['id_113', 12.827793326536716]
['id_114', 10.158057570240517]
['id_115', 27.269223342020993]
['id_116', 29.20873857793244]
['id_117', 8.835339619930764]
['id_118', 20.05108813712979]
['id_119', 20.212333743764248]
['id_120', 79.90600929870556]
['id_121', 18.06161428826361]
['id_122', 30.542809341304327]
['id_123', 25.980792377728037]
['id_124', 5.212577268164774]
['id_125', 30.355697305856207]
['id_126', 7.7683228889146445]
['id_127', 15.328268255393342]
['id_128', 22.666365717697975]
['id_129', 62.74205421109008]
['id_130', 18.950780367988]
['id_131', 19.07635563083853]
['id_132', 61.371574091637115]
['id_133', 15.884562052629704]
['id_134', 13.409418077705558]
['id_135', 0.8487724836112854]
['id_136', 7.8349967173041435]
['id_137', 57.01282901179679]
['id_138', 25.607996751813825]
['id_139', 4.961704729242083]
['id_140', 36.41487903906277]
['id_141', 28.790006721975924]
['id_142', 49.194120961976346]
['id_143', 40.3068698557345]
['id_144', 13.316180593982665]
['id_145', 27.66101187522916]
['id_146', 17.15802752436674]
['id_147', 49.68726256929681]
['id_148', 23.030272291604767]
['id_149', 39.24093652484275]
['id_150', 13.196753889412534]
['id_151', 5.94889370103942]
['id_152', 25.82160897630425]
['id_153', 8.25863421429164]
['id_154', 19.14632051722559]
['id_155', 43.18248652651675]
['id_156', 6.717843578093027]
['id_157', 33.86961524681065]
['id_158', 15.369937846981808]
['id_159', 16.93904497355194]
['id_160', 37.88533679463485]
['id_161', 19.20248454105448]
['id_162', 9.059504715654713]
['id_163', 10.283399610648498]
['id_164', 48.672447125698284]
['id_165', 30.587716213230834]
['id_166', 2.477409897532155]
['id_167', 12.81160393780593]
['id_168', 70.32478980976464]
['id_169', 14.840967694067043]
['id_170', 68.86558756678863]
['id_171', 42.741992444866334]
['id_172', 24.00026154292016]
['id_173', 23.420724860321442]
['id_174', 61.67212443568237]
['id_175', 25.4942028450592]
['id_176', 19.004809786869068]
['id_177', 34.88668288189681]
['id_178', 9.402313398379727]
['id_179', 29.520011314408027]
['id_180', 14.573965885700494]
['id_181', 9.125563143203577]
['id_182', 52.81258399813187]
['id_183', 45.039537994389605]
['id_184', 17.45243467918329]
['id_185', 38.49393527971432]
['id_186', 27.03891909264383]
['id_187', 65.58170967424581]
['id_188', 7.037306380769563]
['id_189', 52.71447713411569]
['id_190', 38.2064593370498]
['id_191', 21.16980105955784]
['id_192', 30.2475568794884]
['id_193', 2.714422989716311]
['id_194', 19.93293258764082]
['id_195', -3.413332337603926]
['id_196', 32.44599940281316]
['id_197', 10.582973029979941]
['id_198', 21.77522570725845]
['id_199', 62.465292065677886]
['id_200', 24.132943687316462]
['id_201', 26.20123964740095]
['id_202', 63.744477234402886]
['id_203', 2.834297774129029]
['id_204', 14.37924698697884]
['id_205', 9.369850731753909]
['id_206', 9.881166613595404]
['id_207', 3.494945358972141]
['id_208', 122.6080493792178]
['id_209', 21.08351301448056]
['id_210', 17.53222059945514]
['id_211', 20.183098344596996]
['id_212', 36.39313221228184]
['id_213', 34.93515120529069]
['id_214', 18.83031266145862]
['id_215', 38.34455552272332]
['id_216', 77.9166341380704]
['id_217', 1.7953235508882321]
['id_218', 13.445827939135775]
['id_219', 36.131155590412135]
['id_220', 15.150403498166302]
['id_221', 12.941848334417905]
['id_222', 113.1252409378639]
['id_223', 15.224604677934368]
['id_224', 14.824025968612045]
['id_225', 59.26735368854045]
['id_226', 10.58369529071846]
['id_227', 20.993062563532174]
['id_228', 9.789365880830378]
['id_229', 4.77118000870598]
['id_230', 47.92780690481291]
['id_231', 12.399438394751055]
['id_232', 48.146476562644125]
['id_233', 40.46638039656413]
['id_234', 16.940590270332937]
['id_235', 41.266544489418735]
['id_236', 69.027892033729]
['id_237', 40.34624924412243]
['id_238', 14.313743982871172]
['id_239', 15.770726634219809]

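The printing above is mainly for inspecting the data and results, and it can be removed. Also, on your own Linux system you can replace the hard-coded file paths with sys.argv, so that file names and locations can be passed in from the terminal.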
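A minimal sketch of the sys.argv idea, assuming the script is run as python hw1.py <train.csv> <test.csv> <submit.csv>; the script name, the argument order, and the get_paths helper are hypothetical, and the defaults are the notebook paths used above:

```python
import sys

DEFAULT_PATHS = ("work/hw1_data/train.csv", "work/hw1_data/test.csv", "work/submit.csv")

def get_paths(argv):
    """Return (train_path, test_path, output_path) from command-line arguments.

    Missing arguments fall back to the hard-coded notebook paths.
    """
    args = argv[1:]  # argv[0] is the script name
    return tuple(args[i] if i < len(args) else DEFAULT_PATHS[i] for i in range(3))

if __name__ == "__main__":
    train_path, test_path, output_path = get_paths(sys.argv)
    print(train_path, test_path, output_path)
```

The three paths would then replace the string literals passed to pd.read_csv and open in the code above.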
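Finally, you can try to beat the baseline by tuning the learning rate (learning_rate), the number of iterations (iter_time), and which features are used (how many hours, which feature rows), or even by switching to a different model.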
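To experiment with which features and how many hours to use, it helps to know that each row of x is laid out feature by feature: column f * 9 + h holds feature f at hour h of the window (h = 0 is the oldest hour), with feature rows in the item order from the dataset section (PM10 is 8, PM2.5 is 9). A sketch of a hypothetical helper that picks a subset of features and the most recent hours; select columns before the ones column is prepended, and apply the same selection to test_x:

```python
import numpy as np

N_FEATURES = 18
N_HOURS = 9

def select_columns(feature_ids, last_hours):
    """Column indices for the given features, keeping only the most recent last_hours hours."""
    hours = range(N_HOURS - last_hours, N_HOURS)
    return [f * N_HOURS + h for f in feature_ids for h in hours]

# Toy one-row matrix whose entries encode (feature, hour) as feature * 100 + hour
x = np.array([[f * 100 + h for f in range(N_FEATURES) for h in range(N_HOURS)]], dtype=float)

# For example: only PM10 (row 8) and PM2.5 (row 9), last 5 hours
cols = select_columns([8, 9], last_hours=5)
print(x[:, cols])  # values 804-808 (PM10, hours 4-8) then 904-908 (PM2.5, hours 4-8)
```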