As is well known, the conjugate gradient method solves linear systems with positive definite matrices very effectively. In practice, however, we often cannot guarantee that the matrix at hand is positive definite. Even then, the conjugate gradient method can still be used as an approximate solver and usually produces a reasonably good result. This post runs the conjugate gradient method on both a positive definite matrix and a "non-positive-definite" matrix and compares the two cases.

Reference:

Generating and testing positive definite matrices
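The reference above covers how to generate and test positive definite matrices. As a quick illustration (the helper name below is my own, not from the original post), checking symmetry plus attempting a Cholesky factorization is a standard practical test:

```python
import numpy as np

def is_positive_definite(A):
    """Return True if A is symmetric positive definite (SPD)."""
    # Cholesky succeeds exactly for symmetric positive definite input,
    # so a raised LinAlgError means A is not SPD.
    if not np.allclose(A, A.T):
        return False
    try:
        np.linalg.cholesky(A)
        return True
    except np.linalg.LinAlgError:
        return False

M = np.random.rand(5, 5)
print(is_positive_definite(M.T @ M + 0.01 * np.eye(5)))  # True: shifted Gram matrix
print(is_positive_definite(M - M.T))                     # False: skew-symmetric
```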

import numpy as np


def conjugate_gradient(A, b, cg_iters=10, residual_tol=1e-10):
    """Solve A x = b with the conjugate gradient method.

    A is expected to be a symmetric positive definite ndarray; for
    other matrices the iteration still runs and, as the experiments
    below show, often returns a usable approximate solution.
    """
    assert isinstance(A, np.ndarray)
    assert isinstance(b, np.ndarray)

    r = np.copy(b)  # initial residual r0 = b - A @ x0 with x0 = 0
    p = np.copy(b)  # initial search direction

    x = np.zeros_like(b)
    rdotr = np.dot(r, r)

    for i in range(cg_iters):
        z = np.dot(A, p)
        v = rdotr / np.dot(p, z)  # step size along direction p
        x += v * p
        r -= v * z                # update the residual
        newrdotr = np.dot(r, r)
        mu = newrdotr / rdotr
        p = r + mu * p            # next A-conjugate search direction

        rdotr = newrdotr
        # report squared residual norm and RMS error of A @ x - b
        print(i, rdotr, np.sqrt(np.mean(np.square(np.dot(A, x)-b))))
        if rdotr < residual_tol:
            break

    return x


if __name__ == '__main__':
    n = 100

    # Construct matrix A for the "non-positive-definite" case. (Strictly,
    # M.T @ M + 0.01*I is symmetric positive definite, but it is severely
    # ill-conditioned, which is why CG struggles on it below.)
    M = np.random.rand(n, n)
    A = np.dot(M.T, M) + 0.01*np.eye(n)

    # Construct a well-conditioned positive definite matrix A_ via an
    # orthogonal similarity transform
    from scipy.linalg import orth
    X = np.diag(np.abs(np.random.rand(n))+1)
    U = orth(np.random.rand(n, n))
    A_ = np.dot(np.dot(U.T, X), U)

    b = np.random.rand(n)


    print("="*30)
    print("Solving with the non-positive-definite matrix:")
    x = conjugate_gradient(A, b, cg_iters=n)

    x_ = np.dot(np.linalg.inv(A), b)

    print("conjugate_gradient:")
    # print(x)
    print("residual error:")
    print(np.sqrt(np.mean(np.square(np.dot(A, x)-b))))
    print()
    print("np.linalg.inv:")
    # print(x_)
    print("residual error:")
    print(np.sqrt(np.mean(np.square(np.dot(A, x_)-b))))
    print()
    print("RMS difference between cg solution and inv solution:")
    print(np.sqrt(np.mean(np.square(x - x_))))


    print("="*30)
    print("Solving with the positive definite matrix:")
    x = conjugate_gradient(A_, b, cg_iters=n)

    x_ = np.dot(np.linalg.inv(A_), b)

    print("conjugate_gradient:")
    # print(x)
    print("residual error:")
    print(np.sqrt(np.mean(np.square(np.dot(A_, x)-b))))
    print()
    print("np.linalg.inv:")
    # print(x_)
    print("residual error:")
    print(np.sqrt(np.mean(np.square(np.dot(A_, x_)-b))))
    print()
    print("RMS difference between cg solution and inv solution:")
    print(np.sqrt(np.mean(np.square(x - x_))))


The code in this post shows that the conjugate gradient method can still achieve fairly good performance when the matrix being solved is not a well-behaved positive definite matrix, which is why the method still appears in many machine learning problems.


For a positive definite matrix, the conjugate gradient method finds the exact solution (in at most n iterations in exact arithmetic); for a matrix that is not well-behaved, it can still produce an approximate solution of fairly good accuracy.
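For comparison with the hand-rolled implementation above, SciPy ships a production-grade conjugate gradient solver. A minimal sketch (assuming `scipy` is installed; the variable names are mine) solving the same kind of well-conditioned positive definite system:

```python
import numpy as np
from scipy.linalg import orth
from scipy.sparse.linalg import cg

n = 100
# Well-conditioned positive definite matrix, built as in the script above.
U = orth(np.random.rand(n, n))
X = np.diag(np.random.rand(n) + 1)  # eigenvalues in [1, 2)
A = U.T @ X @ U
b = np.random.rand(n)

x, info = cg(A, b)  # info == 0 means the iteration converged
residual = np.sqrt(np.mean((A @ x - b) ** 2))
print(info, residual)
```

Because the eigenvalues are clustered in [1, 2), the condition number is at most 2 and `cg` converges in a handful of iterations, mirroring the fast convergence seen in the positive definite run below.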


Sample run output:


==============================
Solving with the non-positive-definite matrix:
0 9.864024055185642 0.3140704388379403
1 8.19449719162157 0.28626032193829387
2 7.580377547708742 0.2753248544484999
3 6.943800402568684 0.26351091822861294
4 4.438405253017757 0.21067522998724234
5 5.994202958711988 0.2448306140725046
6 6.076141568816932 0.24649830767810327
7 418.67772803326017 2.046161596827727
8 6.41049971547943 0.25318964661848775
9 8.652749467971416 0.29415556204109905
10 8.914283982002951 0.29856798190701933
11 11.60044435107292 0.3405942505544197
12 10.439038468043282 0.3230950087519644
13 11.785730308533108 0.3433035145251641
14 884.6301291299438 2.9742732374984753
15 14.37618134559932 0.3791593510069182
16 12.991809878592885 0.36044153310340415
17 12.067040282850645 0.34737645692894725
18 7.039741581000813 0.26532511341749715
19 11.30366077034768 0.33620917254512134
20 9.397934891346498 0.3065605142764803
21 56.319317628252165 0.7504619752409616
22 5.476506150847678 0.23401936139661011
23 3.2735889970562226 0.18093062198136534
24 3.029739945924992 0.17406148183688264
25 2.9782980642516526 0.172577462730551
26 2.0514817327404278 0.1432299456377903
27 9.787381055141552 0.3128479032236229
28 3.1250952874061237 0.17677939041091564
29 1.8858433436721413 0.1373260115080933
30 2.5799022482767193 0.16062074113504254
31 2.8699490882763907 0.16940924084229644
32 2.2468985088098314 0.14989658130890107
33 2.389089891496437 0.1545668105220773
34 138.50155124172267 1.1768668201701247
35 1.8753366379069436 0.13694293110295652
36 1.2073367687248968 0.10987887734794928
37 0.9515548869748356 0.09754767485567717
38 1.2125476044250334 0.11011573931208722
39 1.3696162423916702 0.11703060464644806
40 2.5215197686240947 0.15879293966119318
41 270.1426450487437 1.6436016702616643
42 2.0558278208994243 0.1433815825306467
43 1.57053397317861 0.12532094689949577
44 1.1288277769734387 0.10624630708751567
45 1.266917718046908 0.11255743947190049
46 0.8525437909905479 0.09233329794774761
47 1.0604821044682606 0.10297971181102711
48 21.3628331674485 0.46219945010192176
49 0.8433368331841224 0.09183337264763657
50 0.4278724440703318 0.06541195946234143
51 0.7683347291129197 0.08765470490014537
52 0.8746805780264646 0.09352435928816642
53 0.9480793861874427 0.09736936819079435
54 0.5416094833066896 0.07359412227255001
55 55.76939071140936 0.7467890646723568
56 0.3045570783102137 0.05518669027131416
57 0.2859118314835739 0.05347072390417121
58 0.208780349449446 0.045692488381528056
59 0.2155419650037685 0.04642649728374243
60 0.15328199870102632 0.03915124502501427
61 2.855875188241129 0.1689933486336514
62 0.1786217793903217 0.04226366990579621
63 0.19589457683561218 0.04425997930813217
64 0.10487483521630442 0.032384384387588186
65 0.11151003379199599 0.033393118122158456
66 0.1209706435056663 0.03478083430651505
67 0.12449781690849723 0.03528424817232373
68 1.743583888598339 0.13204483665013328
69 0.10683082610222117 0.03268498525351579
70 0.1310578151047584 0.03620190811333657
71 0.0929517580542835 0.03048799075936344
72 0.09495982611426697 0.03081555226088977
73 0.06702785863901231 0.025889739017417424
74 9.398584311618656 0.3065711061337919
75 0.05742925523335701 0.023964401772925636
76 0.05299740869893356 0.023021166064954256
77 0.13014270463938418 0.036075296899594954
78 0.06795475880156644 0.026068133573703755
79 0.04970340308894039 0.02229426004354329
80 0.4938735970983811 0.07027614083755876
81 0.08566863302243753 0.0292692044684468
82 0.06382663940800216 0.025263934651598707
83 0.042416913357052805 0.020595366798630374
84 0.07743056889173888 0.027826348824734035
85 0.030852217549491558 0.01756479932977135
86 0.03802395871795754 0.019499733002761196
87 1.840982574129803 0.13568281299220739
88 0.02067238949036798 0.014377896052709546
89 0.01161712603008939 0.010778277241798414
90 0.012363094518062044 0.011118945326809923
91 0.013002359992210655 0.011402789129074395
92 0.008424347044390698 0.00917842418088266
93 0.10046995235476161 0.03169699549667079
94 0.007933087991491294 0.00890678841765921
95 0.00638190315627314 0.00798868146587107
96 0.010191073124752498 0.010095084509172791
97 0.012248935857751442 0.011067491069679537
98 0.00925814016357861 0.009621922969765863
99 0.004199959705165739 0.006480709610259408
conjugate_gradient:
residual error:
0.006480709610259408

np.linalg.inv:
residual error:
2.414250255471783e-13

RMS difference between cg solution and inv solution:
0.06483253423109721
==============================
Solving with the positive definite matrix:
0 1.1929422453039054 0.10922189548363942
1 0.038997683490785445 0.019747831144403025
2 0.0013700937974428788 0.003701477809528057
3 2.7723082766175714e-05 0.0005265271385804902
4 9.10187651859136e-07 9.540375526460825e-05
5 2.1097997366339014e-08 1.4525149695015782e-05
6 6.792902053758683e-10 2.606319637698923e-06
7 2.2417605184284453e-11 4.7347233482068875e-07
conjugate_gradient:
residual error:
4.7347233482068875e-07

np.linalg.inv:
residual error:
1.5931411783871734e-16

RMS difference between cg solution and inv solution:
3.4428545978867504e-07