title: 树状数组与线段树经典问题的python实现
date: 2020-03-26 22:13:26
categories: 算法
tags: [python, 树状数组与线段树]
树状数组
作用:单点修改,区间求和
时间复杂度:修改和查询的复杂度都是O(logN)
要点:



k:x在二进制位下面末尾连续0的个数
原理:
利用的负数的存储特性,负数是以补码存储的,对于整数运算 x&(-x)有
1. 当x为0时,即 0 & 0,结果为0;
2. 当x为奇数时,最后一个比特位为1,取反加1没有进位,故x和-x除最后一位外**前面的位正好相反,且最后一位均为1**。结果为1。5&(-5)=(101)&(011)=1
3. 当x为偶数时,取反时,末尾连续k个0均为变成1,加1时,往前进一位,正好是2^k次方。
则定C[x]=sum(x-lowbit(x),x](sum(l,r]表示数组a区间(l,r]的区间和)
区间求和有C[x]定义很明了。
对于单点增加:在图上树的过程,假设增加a[i],我们只需要修改其父亲和祖宗节点,例如增加a[5],我们需要修改C[5],C[6],C[8],C[16],可以证明x的父亲节点有且仅有一个为x+lowbit(x)
动态求连续区间和
给定 n 个数组成的一个数列,规定有两种操作,一是修改某个元素,二是求子数列 [a,b][a,b] 的连续和。
输入格式
第一行包含两个整数n 和 m,分别表示数的个数和操作次数。
第二行包含 n 个整数,表示完整数列。
接下来 m 行,每行包含三个整数 k,k,a,b (k=0,表示求子数列[a,b]的和;k=1,表示第 a 个数加 b)。
数列从 11 开始计数。
输出格式
输出若干行数字,表示 k=0 时,对应的子数列 [a,b] 的连续和。
数据范围
1≤n≤100000
1≤m≤100000,
1≤a≤b≤n
输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
1 1 5
0 1 3
0 4 8
1 7 5
0 4 8
输出样例:
挑战模式
n,m=map(int,input().split())
a=[0 for i in range(0,n+1)]
tr=[0 for i in range(0,n+1)]
def lowbit(x):
return x&(-x)
def query(x):
res=0
while x:
res+=tr[x]
x-=lowbit(x)
return res
def add(x,val):
while x<=n:
tr[x]+=val
x+=lowbit(x)
a=list(map(int,input().split()))
for i in range(0,n):
add(i+1,a[i])
for t in range(0,m):
k,l,r=map(int,input().split())
if k==0:
print(query(r)-query(l-1))
else:
add(l,r)
线段树
本质是二叉树,分治思想。
u的左儿子u<<1(2u)和右儿子u<<1|1(2u+1)
数列区间最大值
输入一串数字,给你 M 个询问,每次询问就给你两个数字 X,Y,要求你说出 X到 Y 这段区间内的最大数。
输入格式
第一行两个整数 N,M 表示数字的个数和要询问的次数;
接下来一行为 N 个数;
接下来 MM 行,每行都有两个整数 X,Y。
输出格式
输出共 M行,每行输出一个数。
数据范围
1≤N≤105,
1≤M≤106,
1≤X≤Y≤N,
数列中的数字均不超过2^31−1
输入样例:
10 2
3 2 4 5 6 8 1 2 9 7
1 4
3 8
输出样例:
class Node:
def __init__(self,l=0,r=0,maxx=0):
self.l=l
self.r=r
self.maxx=maxx
n,m=map(int,input().split())
tr=[Node() for i in range(0,100005*4+100)]
def pushup(u):
tr[u].maxx=max(tr[u<<1].maxx, tr[u<<1|1].maxx)
def build(u,l,r):
if l == r:
tr[u] = Node(l, r, w[r-1])
else:
tr[u]=Node(l,r)
mid=(l+r) >> 1
build(u << 1, l, mid)
build(u << 1 | 1, mid+1, r)
pushup(u)
def query_max(u, l, r):
if tr[u].l >= l and tr[u].r <= r:
return tr[u].maxx
mid=(tr[u].l + tr[u].r) >> 1
maxx=-1e18
if l <= mid : maxx=max(query_max(u << 1, l, r), maxx)
if r > mid : maxx=max(query_max(u << 1 | 1, l, r), maxx)
return maxx
w=list(map(int,input().split()))
build(1, 1, n)
for i in range(0,m):
x,y=map(int,input().split())
print(query_max(1, x, y))