# coding: utf-8
import collections
import numpy as np
import os
from sklearn.neighbors import NearestNeighbors


def cos(vector1, vector2):
    """Return the cosine similarity of two numeric vectors.

    Returns ``None`` when either vector has zero norm, because the
    similarity is undefined in that case.

    NOTE: element pairs are formed with ``zip``. If the vectors differ in
    length, the trailing elements of the longer one are silently ignored,
    so callers must pass two feature vectors of the same length.
    """
    dot_product = 0.0
    norm_a = 0.0
    norm_b = 0.0
    for a, b in zip(vector1, vector2):
        dot_product += a * b
        norm_a += a ** 2
        norm_b += b ** 2
    if norm_a == 0.0 or norm_b == 0.0:
        return None
    # Take each square root separately. Multiplying the two squared norms
    # first can overflow to inf for large-magnitude vectors (e.g. 1e150),
    # which would make the similarity collapse to 0.0.
    return dot_product / (norm_a ** 0.5 * norm_b ** 0.5)


def iterbrowse(path):
    """Lazily yield the full path of every file found under *path*.

    The tree is walked recursively with ``os.walk``. Only the entries that
    ``os.walk`` reports as files are produced (directories are not), in
    walk order. A missing *path* simply yields nothing.
    """
    return (
        os.path.join(directory, entry)
        for directory, _subdirs, entries in os.walk(path)
        for entry in entries
    )


def get_data(filename, expected_fields=78, skip_fields=3):
    """Read a tab-separated feature file into a list of float rows.

    Every non-blank line must contain exactly ``expected_fields``
    tab-separated fields. The first ``skip_fields`` fields are dropped and
    the remaining ones are converted with ``float``.

    :param filename: path of the file to read.
    :param expected_fields: number of fields each line must have (default 78).
    :param skip_fields: number of leading fields to discard (default 3).
    :returns: list of lists of floats, one per non-blank line.
    :raises ValueError: if a line has the wrong number of fields, or a kept
        field cannot be parsed as a float.
    """
    rows = []
    with open(filename) as f:
        # Iterate lazily rather than via readlines() so the file is not
        # materialized as a list of lines first.
        for lineno, line in enumerate(f, 1):
            if not line.strip():
                # Tolerate blank lines, e.g. an extra newline at end of file.
                continue
            fields = line.split("\t")
            if len(fields) != expected_fields:
                raise ValueError(
                    "%s:%d: expected %d tab-separated fields, got %d: %r"
                    % (filename, lineno, expected_fields, len(fields), line))
            # float() ignores the trailing newline left on the last field.
            rows.append([float(n) for n in fields[skip_fields:]])
    return rows

# Column ids that get_wanted_data() drops when called with
# ``unwanted=unwanted_features``.
unwanted_features = {6, 7, 8, 41, 42, 43, 67, 68, 69, 70, 71, 72, 73, 74, 75}


def get_wanted_data(x, unwanted=None, offset=6):
    """Return the feature rows in *x*, optionally dropping unwanted columns.

    With the default ``unwanted=None`` this is an identity pass-through:
    *x* itself is returned unchanged. Passing a set of column ids (for
    example ``unwanted_features``) removes, from every row, each value at
    position ``i`` for which ``i + offset`` is in that set.

    :param x: sequence of rows, each a sequence of feature values.
    :param unwanted: optional set of column ids to drop; ``None`` disables
        filtering.
    :param offset: added to a row position to obtain its column id. The
        default of 6 is taken from the previously disabled filter;
        NOTE(review): confirm it matches the actual column layout.
    :returns: *x* itself when ``unwanted`` is None, otherwise a new list of
        filtered rows.
    """
    if unwanted is None:
        return x
    return [[value for i, value in enumerate(row) if i + offset not in unwanted]
            for row in x]


def _load_training_set(path):
    """Load training rows from a ``.txt`` (whitespace-delimited) or ``.npy`` file.

    :param path: file to load.
    :returns: list of 1-D numpy rows; an empty list if *path* is not a file.
    :raises ValueError: if the extension is neither ``.txt`` nor ``.npy``.
    """
    if not os.path.isfile(path):
        return []
    if path.endswith('.txt'):
        data = np.genfromtxt(path)
    elif path.endswith('.npy'):
        data = np.load(path)
    else:
        # Previously this case fell through and crashed later with a
        # NameError on the unbound array.
        raise ValueError("unsupported training file type: %s" % path)
    # A single-row file loads as a 1-D array; keep one list item per sample.
    return list(np.atleast_2d(data))


def _report_neighbors(neigh, train, samples, title):
    """Print the nearest training row for each sample and its cosine similarity.

    :param neigh: fitted ``NearestNeighbors`` index built on *train*.
    :param train: the rows *neigh* was fitted on, indexable by neighbour index.
    :param samples: query rows with the same columns as *train*.
    :param title: label printed in the header line.
    """
    print("neigh.kneighbors(%s)" % title)
    if len(samples) == 0:
        # sklearn's input validation rejects an empty query array.
        print("no samples")
        return
    distances, indices = neigh.kneighbors(samples)
    print((distances, indices))
    for i, x in enumerate(samples):
        # indices[i] holds *positions* in train, not feature vectors. The
        # similarity must be computed against the neighbour row itself. When
        # neigh uses metric='cosine' it should equal 1 - distances[i][0].
        print(i, indices[i], "cosine:", cos(x, train[indices[i][0]]))


if __name__ == "__main__":
    neg_file = "cc_data/black/black_all.txt"
    # cc_data/white/white_all.txt holds the white training samples; loading
    # it is currently disabled, so the index contains black samples only.
    X = _load_training_set(neg_file)
    print("len of X:", len(X))
    if not X:
        raise SystemExit("no training data loaded from %s" % neg_file)
    # Drop the first three columns, as get_data() does for the query files.
    X = get_wanted_data([x[3:] for x in X])

    black_verify = []
    for f in iterbrowse("todo/top"):
        print(f)
        black_verify += get_data(f)
    black_verify = get_wanted_data(black_verify)
    white_verify = get_wanted_data(get_data("todo/white_verify.txt"))
    unknown_verify = get_wanted_data(get_data("todo/pek_feature74.txt"))
    # Apply the same column selection as the other query sets.
    bd_verify = get_wanted_data(get_data("guzhaoshen_pek_out.txt"))

    neigh = NearestNeighbors(n_neighbors=1, metric='cosine')
    neigh.fit(X)

    _report_neighbors(neigh, X, black_verify, "black_verify")
    _report_neighbors(neigh, X, white_verify, "white_verify")
    _report_neighbors(neigh, X, unknown_verify, "unknown_verify")

    print("neigh.kneighbors(self)")
    print(neigh.kneighbors(X[:3]))

    _report_neighbors(neigh, X, bd_verify, "bd pek")

 输出示例:

neigh.kneighbors(white_verify)
(array([[ 0.01140831],
       [ 0.0067373 ],
       [ 0.00198682],
       [ 0.00686728],
       [ 0.00210445],
       [ 0.00061413],
       [ 0.00453888]]), array([[11032],
       [  967],
       [11091],
       [13149],
       [11091],
       [19041],
       [13068]]))
(0, array([11032]), 'cosine:', 1.0)
(1, array([967]), 'cosine:', 1.0)
(2, array([11091]), 'cosine:', 1.0)
(3, array([13149]), 'cosine:', 1.0)
(4, array([11091]), 'cosine:', 1.0)
(5, array([19041]), 'cosine:', 1.0)
(6, array([13068]), 'cosine:', 1.0)

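样本质量堪忧啊！！！

注意：上面输出中 cosine 全为 1.0，并不能直接说明样本质量有问题，而是脚本本身的 bug。循环里调用的是 cos(x, nearest_points[1][i])，传入的是近邻的**索引数组**（长度为 1），而不是近邻样本的特征向量。zip 只配对了第一个元素，所以结果恒为 ±1.0（或 None）。正确做法是 cos(x, X[nearest_points[1][i][0]])。另外，真正的相似度可以直接看 kneighbors 返回的第一个数组：metric='cosine' 时它是余弦距离，相似度 = 1 − 距离（例如 0.0114 对应约 0.989）。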

 

注意：如果使用常规 KNN（例如欧氏距离），计算距离前记得先对特征做标准化，尤其是当各维度特征的量纲（度量单位）不一致时：

# Standardize every feature (zero mean, unit variance) using statistics
# learned from the training matrix X only. Then apply the *same* transform
# to every query set so that queries and index share one scale. This must
# run after the query sets are loaded and before neigh.fit(X).
from sklearn import preprocessing
scaler = preprocessing.StandardScaler().fit(X)
X = scaler.transform(X)
print("standard X sample:", X[:3])

black_verify = scaler.transform(black_verify)
print(black_verify)

white_verify = scaler.transform(white_verify)
print(white_verify)

unknown_verify = scaler.transform(unknown_verify)
print(unknown_verify)

# bd_verify is queried against the same index, so it needs the same scaling.
bd_verify = scaler.transform(bd_verify)
print(bd_verify)