1. Big integer multiplication

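Since Python supports arbitrary-precision integer arithmetic, Python is used here to compare the running time of the conventional (schoolbook) algorithm and the divide-and-conquer algorithm. The results are shown in the figure below: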

[Figure: running time of the conventional and divide-and-conquer algorithms]

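The horizontal axis is the number of digits of the two factors, and the vertical axis is the time taken by each algorithm. Apart from occasional small fluctuations, the running time of the conventional algorithm grows rapidly with the number of digits, as expected from the O(n^2) digit operations that schoolbook multiplication needs for two n-digit numbers. The running time of the divide-and-conquer algorithm fluctuates more widely as the number of digits grows, but over this range its overall trend shows no marked increase. Overall, the two algorithms take roughly the same time when the factors have about 60 digits.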
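An idealized count of single-digit multiplications makes the difference between the two curves concrete. This sketch assumes both factors have n = 2^k digits and ignores additions, carries and the extra digit that a1 + a0 may produce; the function names are illustrative and not from the original code. Schoolbook multiplication needs n^2 digit products, while each divide-and-conquer step replaces four half-size products by three, so M(1) = 1 and M(n) = 3M(n/2), which gives 3^k = n^(log2 3), roughly n^1.585.

```python
import math


def schoolbook_digit_mults(n: int) -> int:
    """Digit products used by schoolbook multiplication of two n-digit numbers."""
    return n * n


def karatsuba_digit_mults(n: int) -> int:
    """Idealized count from M(1) = 1, M(n) = 3 * M(n / 2); n must be a power of two."""
    if n == 1:
        return 1
    # three half-size products per split instead of four
    return 3 * karatsuba_digit_mults(n // 2)


for k in range(9):
    n = 2 ** k
    print(f"n = {n:3d}: schoolbook {schoolbook_digit_mults(n):6d}, "
          f"divide-and-conquer {karatsuba_digit_mults(n):5d}, "
          f"n**log2(3) = {n ** math.log2(3):8.1f}")
```

For 64-digit factors this gives 4096 digit products versus 729. The count leaves out the Python-level cost of the recursion (function calls, additions, powers of 10), which is one reason a measured crossover such as the one near 60 digits can lie well above the sizes where the counts first separate.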

The Python code is as follows:

from timeit import timeit
import numpy as np
import pandas as pd
np.random.seed(0)

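def fun(n):
    # number of decimal digits of a non-negative integer n (0 for n == 0)
    num = 0
    while n > 0:
        num = num + 1
        n = n // 10   # exact integer division; int(n/10) goes through a float and loses precision
    return num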

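def fun2(a,b):
    # divide-and-conquer (Karatsuba) multiplication of non-negative integers
    if fun(a) < 2 or fun(b) < 2:
        return a * b                      # base case: one factor has a single digit
    n = fun(a) // 2                       # split both factors at 10^n
    a1, a0 = divmod(a, pow(10, n))        # a = a1*10^n + a0 with 0 <= a0 < 10^n (exact)
    b1, b0 = divmod(b, pow(10, n))        # b = b1*10^n + b0
    c2 = fun2(a1, b1)
    c0 = fun2(a0, b0)
    c1 = fun2(a1 + a0, b1 + b0) - (c2 + c0)   # = a1*b0 + a0*b1, one product instead of two
    return c2*pow(10, 2*n) + c1*pow(10, n) + c0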

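def fun3(a,b):
    # conventional (schoolbook) multiplication on digit lists, least significant digit first
    aa=list(map(int,reversed(str(a))))
    bb=list(map(int,reversed(str(b))))
    result = [0]*(len(aa)+len(bb))   # the product has at most len(aa)+len(bb) digits

    for ia,va in enumerate(aa):
        c=0   # carry
        for ib,vb in enumerate(bb):
            # add the digit product into column ia+ib and propagate the carry
            c,result[ia+ib] = divmod(va*vb+c+result[ia+ib],10)
        result[ia+len(bb)] = c
    return int(''.join(map(str,reversed(result))))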

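# upper bounds 10^1 .. 10^200, i.e. operands with up to 1 .. 200 digits
numbers = [pow(10,i) for i in range(1,201)]
t0,t1=[],[]   # t0: conventional algorithm, t1: divide-and-conquer algorithm
for num in numbers:
    a = int(np.random.rand()*num)
    b = int(np.random.rand() * num)
    # time one call of each algorithm on the same pair of operands
    t0.append(timeit(stmt="fun3(a,b)",setup ='from __main__ import fun3,a,b',number=1))
    t1.append(timeit(stmt="fun2(a,b)",setup ='from __main__ import fun2,a,b',number=1))
result = pd.DataFrame({"digits":list(range(1,201)),"conventional":t0,"divide_and_conquer":t1})
result.to_csv('result.csv',index=False)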

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # font that can display Chinese labels
plt.rcParams['axes.unicode_minus']=False # display minus signs correctly with that font
# x = number of digits (1 .. 200)
plt.plot(list(range(1,len(numbers)+1)),t0, linestyle='--',label='Conventional algorithm')
plt.plot(list(range(1,len(numbers)+1)),t1, linestyle='-',label='Divide-and-conquer algorithm')
plt.legend()
plt.xlabel('Number of digits')
plt.ylabel('Time (s)')
plt.show()
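Digit counting and splitting in the divide-and-conquer code need exact integer operations (// or divmod). True division (/) converts its result to a float, which has a 53-bit significand (about 16 significant decimal digits) and overflows above roughly 1.8e308. The short check below, using an arbitrarily chosen 23-digit value, shows both effects: the rounded quotient leaves a "remainder" outside [0, 10^n), and a 400-digit number cannot be divided with / at all.

```python
a = 12345678901234567890123   # 23 digits, beyond what a float holds exactly
n = 3
p = 10 ** n

# Exact split: a == a1 * 10**n + a0 with 0 <= a0 < 10**n
a1, a0 = divmod(a, p)
print("exact:", a1, a0)

# Split through a float: the quotient is rounded to the nearest double
a1_float = int(a / p)
a0_float = a - a1_float * p
print("via float:", a1_float, a0_float)   # a0_float is no longer a 3-digit remainder

# Past about 308 digits, true division cannot produce a float at all
try:
    int(10 ** 400 / 10)
except OverflowError as exc:
    print("OverflowError:", exc)
```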

2. Strassen matrix multiplication

Using Python, we compare the running time of the conventional algorithm and Strassen's algorithm. The results are shown below:

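[Figure: running time of the two algorithms as the input matrix order increases]

The horizontal axis is the exponent n of the matrix order 2^n, and the vertical axis is the time required. As the order increases, the running time of both the conventional algorithm and Strassen's algorithm increases, but at different rates: the time of Strassen's algorithm grows noticeably faster. For orders from 2^0 to 2^5 the two algorithms take about the same time, but from 2^6 upward Strassen's algorithm takes longer than the conventional algorithm. This contradicts the theory above. Reference [4] offers two explanations: (1) the recursion in Strassen's algorithm creates a large number of dynamically allocated two-dimensional arrays, and allocating heap memory for them takes so much time that it masks the algorithm's advantage; (2) Strassen's algorithm is preferable only when the matrices are dense and of very large order.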
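Explanation (2) can be illustrated by counting scalar operations instead of measuring time. The sketch below is an idealized model with hypothetical helper names: n is a power of two, and memory allocation and interpreter overhead are ignored. The conventional algorithm uses n^3 multiplications and n^2(n - 1) additions; each level of Strassen's recursion performs 7 half-size products plus 18 additions or subtractions of (n/2) x (n/2) matrices (10 for S1..S10, 8 for combining the products). The hypothetical `cutoff` parameter multiplies blocks of order at most `cutoff` conventionally.

```python
def conventional_ops(n: int) -> int:
    """Scalar multiplications plus additions for the triple-loop product of two n x n matrices."""
    return n ** 3 + n * n * (n - 1)


def strassen_ops(n: int, cutoff: int = 1) -> int:
    """Idealized operation count for Strassen's algorithm; n must be a power of two.

    Blocks of order <= cutoff (at least 1) are multiplied conventionally.
    """
    if n <= max(cutoff, 1):
        return conventional_ops(n)
    half = n // 2
    # 7 recursive products + 18 additions/subtractions of half x half matrices
    return 7 * strassen_ops(half, cutoff) + 18 * half * half


for k in range(5, 11):
    n = 2 ** k
    print(f"n = {n:4d}: conventional {conventional_ops(n):>13,d}  "
          f"Strassen to 1x1 {strassen_ops(n):>13,d}  "
          f"Strassen, cutoff 64 {strassen_ops(n, 64):>13,d}")
```

In this model, recursing all the way down to 1 x 1 blocks costs more operations than the conventional algorithm for every power-of-two order from 2 up to 512 and fewer from 1024 on, while stopping the recursion at 64 x 64 blocks already saves operations at order 128.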

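Strassen's algorithm therefore needs to be improved. One option is to set an order threshold n1: use the conventional algorithm when n < n1 and Strassen's algorithm when n >= n1. Combining the two algorithms in this way exploits the strengths of each and makes matrix multiplication as fast as possible.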
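A minimal sketch of this combined approach, assuming square matrices stored as lists of lists whose order is a power of two. The names `hybrid_strassen`, `conventional`, `add`, `sub`, `split`, `join` and `cutoff` are hypothetical, not from the original code. `cutoff` plays the role of the threshold: blocks of order at most `cutoff` are multiplied conventionally, matching the experiment below that switches at 2^6. Here the rule is applied at every level of the recursion, which is one reasonable reading of the idea. The seven products and their combination follow the same S1..S10 / P1..P7 scheme as the `strassen` function further down.

```python
import random

Matrix = list[list[int]]


def conventional(a: Matrix, b: Matrix) -> Matrix:
    """Conventional O(n^3) product of two square matrices."""
    cols = list(zip(*b))  # columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def sub(a: Matrix, b: Matrix) -> Matrix:
    return [[x - y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split(a: Matrix) -> tuple[Matrix, Matrix, Matrix, Matrix]:
    """Return the quadrants a11, a12, a21, a22 of a matrix of even order."""
    h = len(a) // 2
    return ([r[:h] for r in a[:h]], [r[h:] for r in a[:h]],
            [r[:h] for r in a[h:]], [r[h:] for r in a[h:]])


def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
    """Assemble four quadrants into one matrix."""
    top = [r1 + r2 for r1, r2 in zip(c11, c12)]
    bottom = [r1 + r2 for r1, r2 in zip(c21, c22)]
    return top + bottom


def hybrid_strassen(a: Matrix, b: Matrix, cutoff: int = 64) -> Matrix:
    """Strassen's algorithm that falls back to the conventional one for blocks of order <= cutoff."""
    n = len(a)
    if n <= max(cutoff, 1):
        return conventional(a, b)
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)
    p1 = hybrid_strassen(a11, sub(b12, b22), cutoff)            # A11 * S1
    p2 = hybrid_strassen(add(a11, a12), b22, cutoff)            # S2 * B22
    p3 = hybrid_strassen(add(a21, a22), b11, cutoff)            # S3 * B11
    p4 = hybrid_strassen(a22, sub(b21, b11), cutoff)            # A22 * S4
    p5 = hybrid_strassen(add(a11, a22), add(b11, b22), cutoff)  # S5 * S6
    p6 = hybrid_strassen(sub(a12, a22), add(b21, b22), cutoff)  # S7 * S8
    p7 = hybrid_strassen(sub(a11, a21), add(b11, b12), cutoff)  # S9 * S10
    c11 = add(sub(add(p5, p4), p2), p6)
    c12 = add(p1, p2)
    c21 = add(p3, p4)
    c22 = sub(sub(add(p5, p1), p3), p7)
    return join(c11, c12, c21, c22)


# usage: compare against the conventional product on random 32 x 32 matrices
random.seed(0)
n = 32
a = [[random.randint(0, 9) for _ in range(n)] for _ in range(n)]
b = [[random.randint(0, 9) for _ in range(n)] for _ in range(n)]
print(hybrid_strassen(a, b, cutoff=4) == conventional(a, b))
```

Choosing `cutoff` is an empirical question: timing a few values on the target machine, as the experiment below does, is the practical way to find it.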

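The figure below shows the running time of the two algorithms used together on matrices of order 2^5 to 2^9 (the conventional algorithm when the order is at most 2^6, Strassen's algorithm otherwise). It shows that for matrices of large order, Strassen's algorithm used in this combined form does take less time than the conventional algorithm.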

[Figure: running time of the improved algorithm as the input matrix order increases]

The Python code is as follows:

def matrixMultiplication(a,b):
    # conventional algorithm: triple loop with n^3 scalar multiplications
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    for i in range(n):
        for j in range(n):
            c[i][j]=0
            for k in range(n):
                c[i][j] += a[i][k]*b[k][j]
    return c

def strassen(a, b):
    # Strassen's algorithm for n x n matrices; n must be a power of 2
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n == 1:
        c[0][0] = a[0][0] * b[0][0]
    else:
        # split both matrices into four n/2 x n/2 quadrants
        (a11, a12, a21, a22) = division(a)
        (b11, b12, b21, b22) = division(b)
        # ten sums/differences S1..S10: add_sub(x, y, 1) = x + y, add_sub(x, y, 0) = x - y
        s1 = add_sub(b12, b22, 0)
        s2 = add_sub(a11, a12, 1)
        s3 = add_sub(a21, a22, 1)
        s4 = add_sub(b21, b11, 0)
        s5 = add_sub(a11, a22, 1)
        s6 = add_sub(b11, b22, 1)
        s7 = add_sub(a12, a22, 0)
        s8 = add_sub(b21, b22, 1)
        s9 = add_sub(a11, a21, 0)
        s10 = add_sub(b11, b12, 1)

        # seven recursive products P1..P7 instead of the eight of a block-wise conventional product
        p1 = strassen(a11, s1)
        p2 = strassen(s2, b22)
        p3 = strassen(s3, b11)
        p4 = strassen(a22, s4)
        p5 = strassen(s5, s6)
        p6 = strassen(s7, s8)
        p7 = strassen(s9, s10)

        # C11 = P5+P4-P2+P6, C12 = P1+P2, C21 = P3+P4, C22 = P5+P1-P3-P7
        c11 = add_sub(add_sub(add_sub(p5, p4, 1), p2, 0), p6, 1)
        c12 = add_sub(p1, p2, 1)
        c21 = add_sub(p3, p4, 1)
        c22 = add_sub(add_sub(add_sub(p5, p1, 1), p3, 0), p7, 0)
        c = combination(c11, c12, c21, c22)
    return c

def division(a):  # split a square matrix of even order into quadrants a11, a12, a21, a22
    n = len(a) // 2
    a11 = [[0 for i in range(n)] for j in range(n)]
    a12 = [[0 for i in range(n)] for j in range(n)]
    a21 = [[0 for i in range(n)] for j in range(n)]
    a22 = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j + n]
            a21[i][j] = a[i + n][j]
            a22[i][j] = a[i + n][j + n]
    return (a11, a12, a21, a22)

def add_sub(a, b, keys):
    # element-wise a + b when keys == 1, otherwise a - b
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys == 1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] + b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] - b[i][j]
    return c

def combination(a11, a12, a21, a22):
    # assemble four quadrants into one matrix of twice their order
    n2 = len(a11)
    n = n2 * 2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range(0, n):
        for j in range(0, n):
            if i <= (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a11[i][j]
            elif i <= (n2 - 1) and j > (n2 - 1):
                a[i][j] = a12[i][j - n2]
            elif i > (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a21[i - n2][j]
            else:
                a[i][j] = a22[i - n2][j - n2]
    return a

import numpy as np
import pandas as pd
from timeit import timeit
numbers = [pow(2,i) for i in range(8)]   # matrix orders 2^0 .. 2^7
t0,t1=[],[]   # t0: conventional algorithm, t1: Strassen's algorithm
for num in numbers:
    # random num x num matrices with integer entries 0..9
    a = np.random.randint(10,size=(num,num))
    b = np.random.randint(10,size=(num,num))
    t0.append(timeit(stmt="matrixMultiplication(a,b)",setup ='from __main__ import matrixMultiplication,a,b',number=1))
    t1.append(timeit(stmt="strassen(a,b)",setup ='from __main__ import strassen,a,b',number=1))

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # font that can display Chinese labels
plt.rcParams['axes.unicode_minus']=False # display minus signs correctly with that font

# x = exponent i of the matrix order 2^i
plt.plot(list(range(len(numbers))),t0,linestyle='--',label='Conventional algorithm')
plt.plot(list(range(len(numbers))),t1,linestyle='-',label='Strassen algorithm')
plt.legend()
plt.xlabel('Matrix order (exponent i of 2^i)')
plt.ylabel('Time (s)')
plt.show()
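Because `add_sub(x, y, 1)` adds and `add_sub(x, y, 0)` subtracts, the signs in `strassen` are easy to misread. The self-contained check below (hypothetical function name `strassen_step`) applies one level of the same S1..S10 / P1..P7 formulas to NumPy block matrices, uses the ordinary `@` product for the seven block products, and compares the result with `a @ b`. Block products do not commute, so the check also confirms that the formulas do not rely on commutativity. NumPy basic slicing returns views, so the quadrants themselves are not copied, which is relevant to explanation (1) above; the sums and products still allocate new arrays.

```python
import numpy as np


def strassen_step(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One level of Strassen's scheme on square matrices of even order."""
    h = a.shape[0] // 2
    # quadrants as views (no copying)
    a11, a12, a21, a22 = a[:h, :h], a[:h, h:], a[h:, :h], a[h:, h:]
    b11, b12, b21, b22 = b[:h, :h], b[:h, h:], b[h:, :h], b[h:, h:]
    s1, s2, s3, s4, s5 = b12 - b22, a11 + a12, a21 + a22, b21 - b11, a11 + a22
    s6, s7, s8, s9, s10 = b11 + b22, a12 - a22, b21 + b22, a11 - a21, b11 + b12
    p1, p2, p3, p4 = a11 @ s1, s2 @ b22, s3 @ b11, a22 @ s4
    p5, p6, p7 = s5 @ s6, s7 @ s8, s9 @ s10
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p5 + p1 - p3 - p7
    return np.block([[c11, c12], [c21, c22]])


rng = np.random.default_rng(0)
for n in (2, 4, 8, 16):
    a = rng.integers(0, 10, size=(n, n))
    b = rng.integers(0, 10, size=(n, n))
    assert np.array_equal(strassen_step(a, b), a @ b)
print("one Strassen step matches a @ b")
```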