def binSplitDataSet(dataSet, feature, value):
    """Split the rows of dataSet on one feature column at a threshold.

    Args:
        dataSet: 2-D numpy array/matrix of samples (one row per sample).
        feature: index of the column to compare against the threshold.
        value: the threshold.

    Returns:
        (above, below): the rows whose feature is > value, and the rows
        whose feature is <= value.  A row for which neither comparison
        holds (e.g. a NaN feature) appears in neither part.
    """
    column = dataSet[:, feature]
    aboveRows = nonzero(column > value)[0]
    belowRows = nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[belowRows, :]
def loadDataSet(fileName):
    """Read a tab-separated numeric text file into a list of float rows.

    Each non-blank line is stripped of surrounding whitespace, split on
    tab characters, and every field is converted with float().

    Args:
        fileName: path of the text file to read.

    Returns:
        A list with one list of floats per non-blank line.

    Raises:
        IOError/OSError if the file cannot be opened; ValueError if a field
        is not a valid float literal.
    """
    dataMat = []
    # `with` guarantees the handle is closed, even when float() raises.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. an extra trailing newline) instead
                # of failing on float('').
                continue
            # list() so each row is a real list on Python 3 as well, where
            # map() returns a one-shot iterator rather than a list.
            dataMat.append(list(map(float, stripped.split('\t'))))
    return dataMat
def regLeaf(dataSet):
    """Return the value stored in a regression-tree leaf.

    The leaf value is the mean of the last column of dataSet, which the
    tree code uses as the target.
    """
    targets = dataSet[:, -1]
    return mean(targets)
def regErr(dataSet):
    """Return the total squared error of the target column.

    Computed as the variance of the last column (numpy's default ddof=0)
    multiplied by the number of rows, i.e. the sum of squared deviations
    of the targets from their mean.
    """
    targetVariance = var(dataSet[:, -1])
    numRows = shape(dataSet)[0]
    return targetVariance * numRows
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of dataSet, or decide to make a leaf.

    Every (feature, value) pair is tried, where feature ranges over all
    columns except the last (the target) and value over the distinct values
    of that column.  The split minimising errType(above) + errType(below)
    is kept.

    Args:
        dataSet: 2-D numpy matrix; column slices must stay 2-D, as the
            `.T.tolist()[0]` calls below require.  Last column is the target.
        leafType: callable(dataSet) -> leaf value, used when no split is made.
        errType: callable(dataSet) -> error measure the split minimises.
        ops: (tolS, tolN) where tolS is the minimum error reduction a split
            must achieve and tolN the minimum number of rows on each side.

    Returns:
        (featureIndex, splitValue) for the chosen split, or
        (None, leafType(dataSet)) when the node should be a leaf.
    """
    tolS = ops[0]  # minimum error reduction required to accept a split
    tolN = ops[1]  # minimum number of samples allowed in each child
    n = shape(dataSet)[1]
    # All targets identical: splitting cannot reduce the error.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestf = 0
    bestv = 0
    for i in range(n - 1):  # the last column is the target; never split on it
        for j in set(dataSet[:, i].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, i, j)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestf = i
                bestv = j
                bestS = newS
    # bestS still inf means no candidate passed tolN (or every candidate
    # error was NaN): the placeholder (0, 0) was never evaluated and must not
    # be returned.  This check must not rely on `S - bestS`, which is NaN
    # (not < tolS) when S itself is NaN.  Every candidate that updated bestS
    # already satisfied tolN, so no re-split check is needed afterwards.
    if bestS == inf or (S - bestS) < tolS:
        return None, leafType(dataSet)
    return bestf, bestv
def creatTree(dataSet, leafType=regLeaf, errType=regErr, ops=(0, 1)):
    """Recursively grow a binary regression tree from dataSet.

    Note the default ops=(0, 1) here differs from chooseBestSplit's own
    default of (1, 4).

    Args:
        dataSet: 2-D numpy matrix whose last column is the target.
        leafType, errType, ops: forwarded unchanged to chooseBestSplit at
            every level of the recursion.

    Returns:
        The leaf value (whatever leafType returns) when no split is chosen;
        otherwise a dict with keys 'spInd' (feature index), 'spVal'
        (threshold), 'left' (subtree built from rows with feature > spVal)
        and 'right' (subtree built from rows with feature <= spVal).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:  # chooseBestSplit decided this node is a leaf
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree = {
        'spInd': feat,
        'spVal': val,
        'left': creatTree(lSet, leafType, errType, ops),
        'right': creatTree(rSet, leafType, errType, ops),
    }
    return retTree