hadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM4NjQxOTg1,size_16,color_FFFFFF,t_70)
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from sklearn import linear_model
import random
def create_data():
    """Return the fixed four-sample training set used by the first demo.

    Returns:
        X: integer array of shape (4, 3); every row is ``[1, x, y]`` where the
            leading 1 is a constant (bias) input.
        Y: integer column vector of shape (4, 1) holding the labels
            +1, +1, -1, -1 in the same row order as ``X``.
    """
    points = [(3, 3), (4, 3), (1, 1), (0, 2)]
    labels = [1, 1, -1, -1]
    X = np.array([[1, px, py] for px, py in points])
    Y = np.array(labels).reshape(len(labels), 1)
    return X, Y
def update_ws(X, Y, W, lr=0.01):
    """Apply one batch perceptron update and return the new weights.

    Computes ``W + lr * X.T @ (Y - sign(X @ W)) / n``, where ``n`` is the
    number of rows of ``X``. Correctly classified samples contribute nothing
    to the sum. A sample lying exactly on the boundary gets ``sign(...) == 0``
    and therefore still pulls the weights toward its label.

    Args:
        X: sample matrix, one row per sample (e.g. shape (4, 3)).
        Y: column vector of +1/-1 labels, shape (n, 1).
        W: current weight column vector, shape (3, 1). It is not modified.
        lr: learning rate.

    Returns:
        A new weight array with the same shape as ``W``.
    """
    V = np.sign(np.dot(X, W))  # predictions: (n, 3) @ (3, 1) -> (n, 1)
    # Correction averaged over the batch: (3, n) @ (n, 1) -> (3, 1).
    W_ = lr * (X.T.dot(Y - V)) / int(X.shape[0])
    return W + W_
def main(epochs=100, lr=0.11):
    """Train a perceptron on the create_data() samples and plot the boundary.

    Args:
        epochs: maximum number of update_ws passes. Training stops early as
            soon as every sample is classified correctly.
        lr: learning rate forwarded to update_ws.
    """
    x_, y_ = create_data()
    # Random initial weights drawn uniformly from [-1, 1).
    W = (np.random.random([3, 1]) - 0.5) * 2
    for i in range(epochs):
        W = update_ws(x_, y_, W, lr)
        V_ = np.sign(np.dot(x_, W))
        if (V_ == y_).all():
            print("epochs = {}".format(i + 1))
            break
    # Column 0 of x_ is the constant 1, so the boundary
    # W0 + W1*x + W2*y = 0 becomes the line y = k*x + b.
    k = -W[1] / W[2]
    b = -W[0] / W[2]
    print("k = {},b= {}".format(k, b))
    x_axis = np.array([0, 5])
    plt.plot(x_axis, x_axis * k + b, 'r')
    # Plot the actual training samples (columns 1 and 2) coloured by label,
    # instead of a hand-copied list that could drift from create_data().
    labels = y_.ravel()
    positives = x_[labels == 1]
    negatives = x_[labels == -1]
    plt.scatter(positives[:, 1], positives[:, 2], c='b')
    plt.scatter(negatives[:, 1], negatives[:, 2], c='g')
    plt.show()
# Run the first demo: perceptron training on the fixed four-point data set.
main()
随机生成范围数据
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from sklearn import linear_model
import random
def create_data_(n=5):
    """Generate a random, linearly separable two-class data set.

    The first ``n`` samples are class -1, with both features drawn from
    [-5, -1]. The next ``n`` samples are class +1, with both features drawn
    from [0, 4]. The two classes are therefore always separable by a line
    (for example x + y + 1 = 0).

    Args:
        n: number of samples per class (default 5, i.e. 10 points in total).

    Returns:
        X: integer array of shape (2n, 3); each row is ``[1, f1, f2]``.
        Y: integer array of shape (2n, 1): -1 for the first n rows, +1 after.
    """
    neg = np.random.randint(-5, 0, size=(n, 2))  # upper bound is exclusive
    pos = np.random.randint(0, 5, size=(n, 2))
    features = np.concatenate((neg, pos), axis=0)  # stack rows, class -1 first
    # Column 0 must be the constant 1 so that W[0] acts as the bias term;
    # main relies on this when it turns W into the line y = k*x + b.
    bias = np.ones((2 * n, 1), dtype=features.dtype)
    X = np.concatenate((bias, features), axis=1)
    Y = np.concatenate((np.full((n, 1), -1), np.full((n, 1), 1)), axis=0)
    return X, Y


def update_ws(X, Y, W, lr=0.01):
    """Apply one batch perceptron update and return the new weights.

    Computes ``W + lr * X.T @ (Y - sign(X @ W)) / n``, where ``n`` is the
    number of rows of ``X``. ``W`` itself is not modified.
    """
    V = np.sign(np.dot(X, W))  # predictions: (n, 3) @ (3, 1) -> (n, 1)
    W_ = lr * (X.T.dot(Y - V)) / int(X.shape[0])  # (3, n) @ (n, 1) -> (3, 1)
    W = W + W_
    return W


def main(epochs=1000, lr=0.02):
    """Train a perceptron on create_data_() samples and plot the result.

    Args:
        epochs: maximum number of update_ws passes. Training stops as soon as
            every sample is classified correctly.
        lr: learning rate forwarded to update_ws.
    """
    x_, y_ = create_data_()
    # Random initial weights drawn uniformly from [-1, 1).
    W = (np.random.random([3, 1]) - 0.5) * 2
    for i in range(epochs):
        W = update_ws(x_, y_, W, lr)
        V_ = np.sign(np.dot(x_, W))
        if (V_ == y_).all():
            print("epochs = {}".format(i + 1))
            # With every sample correct, Y - V is all zeros, so later updates
            # cannot change W; stop instead of re-printing each epoch.
            break
    print("W = {}.".format(W))
    # Boundary W0 + W1*x + W2*y = 0 rewritten as y = k*x + b.
    k = -W[1] / W[2]
    b = -W[0] / W[2]
    print("k = {},b= {}".format(k, b))
    x_axis = np.array([-5, 5])
    # Plot the feature columns (1 and 2), coloured by the actual label.
    labels = y_.ravel()
    negatives = x_[labels < 0]
    positives = x_[labels > 0]
    plt.scatter(negatives[:, 1], negatives[:, 2], c='b', marker="o")
    plt.scatter(positives[:, 1], positives[:, 2], c='g', marker="o")
    plt.plot(x_axis, x_axis * k + b, 'r')
    plt.show()
# Run the second demo (the most recent definition of main): random data.
main()