[python刷题模板] 一些数论工具

  • 一、 算法&数据结构
  • 1. 描述
  • 2. 复杂度分析
  • 3. 常见应用
  • 4. 常用优化
  • 二、 模板代码
  • 1. 不用浮点数求向上取整公式 ceil(a/b) = (a+b-1)//b
  • 2. 快速幂pow(m,x,MOD)
  • 3. 组合数C(m,r) = comb(m,r)
  • 4. 排列数P(m,R) = perm(m,r)
  • 4.1 含重复元素的全排列种类数p = C(n,c1) * C(n-c1,c2) * C(n-c1-c2,c3) * ... * C(cx,cx)
  • 5. 最大公约数gcd = math.gcd(a,b)
  • 6. 最小公倍数lcm = math.lcm(a,b)
  • 7. 质数筛法:埃氏筛O(nlgn)、线性筛O(n)(欧拉筛)
  • 8. 求一个数所有质因子以及个数。
  • 三、其他
  • 四、更多例题
  • 五、参考链接


一、 算法&数据结构

1. 描述

数论太难了,有的思路第二次写又要重新想,这里记录一些数学相关常用方法,备查。

2. 复杂度分析

  1. 挨个分析。

3. 常见应用

  1. 组合数学、取模…
  2. 计算几何

4. 常用优化

  1. 记忆化。

二、 模板代码

1. 不用浮点数求向上取整公式 ceil(a/b) = (a+b-1)//b

涉及到取整,如果直接ceil就有浮点数精度问题,问就是wa过。

# ceil(a/b) = (a+b-1)//b
s = (a+b-1)//b

2. 快速幂pow(m,x,MOD)

这个我写过一篇[python刷题模板] 取模技巧/组合数取模模板/快速幂取模模板 但是py其实pow函数第三个参数支持模。。。
或者由于py支持大数,其实可以直接m**x%MOD

ans = pow(m,x,MOD)
ans = m**x%MOD

3. 组合数C(m,r) = comb(m,r)

py自带组合数方法,支持大数运算,别忘了外边取模

c = comb(m,r)
class ModComb:
    def __init__(self, n, p):
        """
        初始化,为了防止模不一样,因此不写默认值,强制要求调用者明示
        :param n:最大值,通常是2*(10**5)+50
        :param p: 模,通常是10**9+7
        """
        self.p = p
        self.inv_f, self.fact = [1] * (n + 1), [1] * (n + 1)  # 阶乘的逆元、阶乘
        inv_f, fact = self.inv_f, self.fact
        for i in range(2, n + 1):
            fact[i] = i * fact[i - 1] % p
        inv_f[-1] = pow(fact[-1], p - 2, p)
        for i in range(n, 0, -1):
            inv_f[i - 1] = i * inv_f[i] % p

    def comb(self, m, r):
        if m < r or r < 0:
            return 0
        return self.fact[m] * self.inv_f[r] % self.p * self.inv_f[m - r] % self.p

    def perm_count_with_duplicate(self, a):
        """含重复元素的列表a,全排列的种类。
        假设长度n,含x种元素,分别计数为[c1,c2,c3..cx]
        则答案是C(n,c1)*C(n-c1,c2)*C(n-c1-c2,c3)*...*C(cx,cx)
        或:n!/c1!/c2!/c3!/../cn!
        """
        ans = self.fact[len(a)]
        for c in Counter(a).values():
            ans = ans * self.inv_f[c] % self.p           
        return ans 
        # 下边这种也可以
        # s = len(a)
        # ans = 1
        # for c in Counter(a).values():
        #     ans = ans * self.comb(s,c) % MOD 
        #     s -= c
        # return ans

4. 排列数P(m,R) = perm(m,r)

perm如果不传入第二个参数,就是阶乘

p = perm(m,r)
perm(n) == factorial(n)

4.1 含重复元素的全排列种类数p = C(n,c1) * C(n-c1,c2) * C(n-c1-c2,c3) * … * C(cx,cx)

6276. 统计同位异构字符串数目

MOD = 10  ** 9 + 7
class Solution:
    def countAnagrams(self, s: str) -> int:
        ret = 1
        def perm_count_with_duplicate(a):
            """含重复元素的列表a,全排列的种类。
            假设长度n,含x种元素,分别计数为[c1,c2,c3..cx]
            则答案是C(n,c1)*C(n-c1,c2)*C(n-c1-c2,c3)*...*C(cx,cx)
            """
            s = len(a)
            ans = 1
            for c in Counter(a).values():
                ans = ans * comb(s,c) % MOD 
                s -= c
            return ans 
      

        for w in s.split():
            ret = ret * perm_count_with_duplicate(w) %MOD

        return (ret)%MOD

5. 最大公约数gcd = math.gcd(a,b)

注意这个gcd支持传入多参数,有两种写法,建议用星号,因为reduce如果a是空数组会报错。
注意gcd(a,0)=a,即任意数和0的gcd都是自己,参照循环相减法。

ans = gcd(a,b)
a = [3,5,10]
print(gcd(*a))
print(reduce(gcd,a))

6. 最小公倍数lcm = math.lcm(a,b)

写法和gcd类似。

ans = lcm(a,b)
a = [3,5,10]
print(lcm(*a))
print(reduce(lcm,a))

7. 质数筛法:埃氏筛O(nlgn)、线性筛O(n)(欧拉筛)

  • 由于py切片优化的原因,埃氏筛的表现超过线性筛
    欧拉筛
def tag_primes_euler(n):  # 返回一个长度n+1的数组p,如果i是质数则p[i]=1否则p[i]=0
    primes = [1]*(n+1)
    primes[0] = primes[1] = 0  # 0和1不是质数
    ps = []  # 记质数
    for i in range(2,n+1):
        if primes[i]:
            ps.append(i)
        for j in ps:
            if j*i>n:
                break
            primes[j*i] = 0
            if i%j == 0:break
    # print(ps)
    return primes
            
primes = tag_primes_euler(10**5+5)

埃氏筛给所有素数标记1

def tag_primes_eratosthenes(n):  # 返回一个长度n的数组p,如果i是质数则p[i]=1否则p[i]=0
    primes = [1]*n
    primes[0] = primes[1] = 0  # 0和1不是质数
    for i in range(2,int(n**0.5)+1):
        if primes[i]:
            primes[i * i::i] = [0] * ((n - 1 - i * i) // i + 1)
    return primes
primes = tag_primes_eratosthenes(5*10**5+5)

8. 求一个数所有质因子以及个数。

  • 分解质因数可以用用线性筛处理,达到很优。
  • 如果是较大的数如x=1e8之类的数分解质因数,可以先处理sqrt(n)以内的质数,在质数上枚举,而不是在2~sqrt(n)枚举,这样复杂度会除10左右。(x超过1e5时,x以内的质数数量<x/10,且越来越小)
  • 遇到一题,n=1e5,a[i]=1e8,如果直接枚举sqrt(U),复杂度就是1e9,但在质数上枚举,复杂度就是1e8,能过。星石传送阵【算法赛】

def get_prime_reasons(x):
    # 获取x的所有质因数,虽然是两层循环且没有判断合数,但复杂度依然是O(sqrt(x))
    # 由于i是从2开始增加,每次都除完,因此所有合数的因数会提前除完,合数不会被x整除的
    if x == 1:
        return Counter()
    ans = Counter()
    i = 2
    while i*i<=x:
        while x%i==0:
            ans[i] += 1
            x //= i 
        i += 1
    if x > 1: ans[x] += 1
    return ans

class Solution:
    def smallestValue(self, n: int) -> int:
        while n:
            cnt = get_prime_reasons(n)
            s = 0
            for k,v in cnt.items():
                s += k*v
            if s == n:return n
            n = s
class Solution:
    def smallestValue(self, n: int) -> int:
        while True:
            x, s, i = n, 0, 2
            while i * i <= x:
                while x % i == 0:
                    s += i
                    x //= i
                i += 1
            if x > 1: s += x
            if s == n: return n
            n = s

多次使用,先搞个质数筛加快运行;然后记忆化。

def tag_primes_euler(n):  # 返回一个长度n+1的数组p,如果i是质数则p[i]=1否则p[i]=0
    primes = [1]*(n+1)
    primes[0] = primes[1] = 0  # 0和1不是质数
    ps = []  # 记质数
    for i in range(2,n+1):
        if primes[i]:
            ps.append(i)
        for j in ps:
            if j*i>n:
                break
            primes[j*i] = 0
            if i%j == 0:break
    # print(ps)
    return primes
            
primes = tag_primes_euler(10**5+5)


@cache
def get_prime_reasons(x):
    if x == 1:
        return Counter()
    if primes[x]:
        return Counter([x])
    for i in range(2,int(x**0.5)+1):
        if x % i == 0:
            return get_prime_reasons(i) + get_prime_reasons(x//i)

三、其他



四、更多例题



五、参考链接