线段树+矩阵。。。。
我们可以把第i层跟第i+1层之间楼梯的通断性构造成一个2*2的通断性矩阵,1表示通,0表示不通。那么从第a层到第b层,就是将a到b-1的通断性矩阵连乘起来,然后将得到的答案矩阵上的每个元素加起来即为方案数。想到矩阵的乘法是满足结合律的,那么我们可以用线段树来维护矩阵的乘积。每次我们只会修改某一个楼梯的通断性,所以就只是简单的线段树单点更新,成段求乘积而已。
整体复杂度
2∗2∗2∗nlogn
#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
#define maxn 50005
#define maxm 2000005
#define eps 1e-10
#define mod 1000000007
#define INF 1e9
#define lowbit(x) (x&(-x))
#define mp make_pair
#define ls o<<1
#define rs o<<1 | 1
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL;
using namespace std;
struct matrix
{
LL mat[3][3];
matrix operator * (const matrix& a) const {
matrix res;
for(int i = 1; i <= 2; i++)
for(int j = 1; j <= 2; j++) {
LL t = 0;
for(int k = 1; k <= 2; k++)
t = (t + mat[i][k] * a.mat[k][j]) % mod;
res.mat[i][j] = t;
}
return res;
}
}sum[maxn << 2], one;
int n, m;
void pushup(int o)
{
sum[o] = sum[ls] * sum[rs];
}
void build(int o, int L, int R)
{
if(L == R) {
sum[o].mat[1][1] = sum[o].mat[1][2] = 1;
sum[o].mat[2][1] = sum[o].mat[2][2] = 1;
return;
}
int mid = (L + R) >> 1;
build(lson);
build(rson);
pushup(o);
}
void updata(int o, int L, int R, int q, int a, int b)
{
if(L == R) {
sum[o].mat[a][b] ^= 1;
return;
}
int mid = (L + R) >> 1;
if(q <= mid) updata(lson, q, a, b);
else updata(rson, q, a, b);
pushup(o);
}
matrix query(int o, int L, int R, int ql, int qr)
{
if(ql <= L && qr >= R) return sum[o];
int mid = (L + R) >> 1;
matrix ans = one;
if(ql <= mid) ans = ans * query(lson, ql, qr);
if(qr > mid) ans = ans * query(rson, ql, qr);
return ans;
}
void work(void)
{
build(1, 1, n);
one.mat[1][1] = one.mat[2][2] = 1;
one.mat[1][2] = one.mat[2][1] = 0;
int a, b, c, k;
while(m--) {
scanf("%d", &k);
if(!k) {
scanf("%d%d", &a, &b);
matrix tt = query(1, 1, n, a, b-1);
LL ans = 0;
for(int i = 1; i <= 2; i++)
for(int j = 1; j <= 2; j++)
ans = (ans + tt.mat[i][j]) % mod;
printf("%I64d\n", ans);
}
else {
scanf("%d%d%d", &a, &b, &c);
updata(1, 1, n, a, b, c);
}
}
}
int main(void)
{
while(scanf("%d%d", &n, &m)!=EOF) work();
return 0;
}