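// Approach: split each query rectangle into triangles and combine them by inclusion-exclusion; each triangle is then answered with a segment tree that supports arithmetic-progression (offset-weighted) sums.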
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// Recursion helper macros. They expand to argument lists and rely on the
// local names o, L, R, mid and tree being in scope wherever they are used.
#define now o, L, R, tree
#define lson o << 1, L, mid, tree
#define rson o << 1 | 1, mid+1, R, tree
#define ls o << 1
#define rs o << 1 | 1
// Size bound for the segment-tree node arrays (each array has maxn << 2 nodes).
constexpr int maxn = 400005;
struct node
{
LL sum, val, lazy;
node(LL sum = 0, LL val = 0, LL lazy = 0) : sum(sum), val(val), lazy(lazy) {}
}a[maxn << 2], b[maxn << 2];
// Per-test-case input. Both trees span positions [1, 2n], and type-2 indices
// are shifted by n. m is the number of operations still to process.
int n;
int m;
void pushdown(int o, int L, int R, node tree[])
{
if(tree[o].lazy) {
int mid = (L + R) >> 1;
LL tl = mid-L+1, tr = R-mid;
tree[ls].sum += (1 + tl) * tl / 2 * tree[o].lazy;
tree[rs].sum += (1 + tr) * tr / 2 * tree[o].lazy;
tree[ls].val += tl * tree[o].lazy;
tree[rs].val += tr * tree[o].lazy;
tree[ls].lazy += tree[o].lazy;
tree[rs].lazy += tree[o].lazy;
tree[o].lazy = 0;
}
}
void pushup(int o, int L, int R, node tree[])
{
int mid = (L + R) >> 1;
tree[o].val = tree[ls].val + tree[rs].val;
tree[o].sum = tree[ls].sum + tree[rs].sum + tree[rs].val * (mid-L+1);
}
void build(int o, int L, int R, node tree[])
{
tree[o].sum = tree[o].val = tree[o].lazy = 0;
if(L == R) return;
int mid = (L + R) >> 1;
build(lson);
build(rson);
pushup(now);
}
// Add 1 to every count c_i with i in [ql, qr] (clipped to this node's [L, R]).
// Node invariants maintained:
//   val  = sum of c_i over [L, R]
//   sum  = sum of c_i * (i - L + 1)
//   lazy = pending "+1 per position" operations not yet pushed to the children
// An empty (ql > qr) or disjoint interval is a no-op. Without the guard, such
// an interval descends to a leaf, fails the cover test, calls pushdown on the
// leaf (touching indices below it) and then recurses into the same leaf forever.
void update(int o, int L, int R, node tree[], int ql, int qr)
{
    if(ql > qr || qr < L || ql > R) return;
    if(ql <= L && qr >= R) {
        LL len = R - L + 1;
        tree[o].lazy += 1;
        tree[o].val += len;
        tree[o].sum += (1 + len) * len / 2; // 1 + 2 + ... + len
        return;
    }
    pushdown(now);
    int mid = (L + R) >> 1;
    if(ql <= mid) update(lson, ql, qr);
    if(qr > mid) update(rson, ql, qr);
    pushup(now);
}
// Query the counts on [ql, qr] intersected with this node's [L, R].
// Returns a node where:
//   val = sum of c_i over the covered positions
//   sum = sum of c_i * (i - s + 1), where s = max(ql, L) is the first covered position
// (the returned lazy field carries no meaning).
// An empty (ql > qr) or disjoint range yields an all-zero node. Callers can
// produce empty ranges, e.g. solve() with a single-row/column rectangle
// (x1 == x2 or y1 == y2); without this guard the recursion never terminates.
node query(int o, int L, int R, node tree[], int ql, int qr)
{
    if(ql > qr || qr < L || ql > R) return node();
    if(ql <= L && qr >= R) return tree[o];
    // pushdown only updates the children; tree[o].sum/val stay unchanged,
    // so no pushup is needed afterwards.
    pushdown(now);
    int mid = (L + R) >> 1;
    if(qr <= mid) return query(lson, ql, qr);
    if(ql > mid) return query(rson, ql, qr);
    node ans = query(lson, ql, qr);
    node t = query(rson, ql, qr);
    ans.val += t.val;
    // t.sum is weighted from mid+1; re-base it onto the left part's start max(ql, L).
    ans.sum += t.sum + static_cast<LL>(mid - std::max(ql, L) + 1) * t.val;
    return ans;
}
// Offset-weighted sum over tree b for indices [ql, qr]. The indices are shifted
// by n, the same shift applied to type-2 updates in work().
LL solve2(int ql, int qr)
{
    const node res = query(1, 1, 2 * n, b, ql + n, qr + n);
    return res.sum;
}
// Offset-weighted sum over tree a for indices [ql, qr] (no index shift).
LL solve1(int ql, int qr)
{
    const node res = query(1, 1, 2 * n, a, ql, qr);
    return res.sum;
}
void solve()
{
int x1, y1, x2, y2;
scanf("%d%d%d%d", &x1, &x2, &y1, &y2);
LL ans = 0;
ans += solve2(x1 - y2, x2 - y1);
ans -= solve2(x1 - y1 + 1, x2 - y1);
ans -= solve2(x2 - y2 + 1, x2 - y1);
ans += solve1(x1 + y1, x2 + y2);
ans -= solve1(x2 + y1 + 1, x2 + y2);
ans -= solve1(x1 + y2 + 1, x2 + y2);
printf("%lld\n", ans);
}
// Run one test case. Read n and m, reset both trees over [1, 2n], then process
// m operations:
//   1 l r : +1 on tree a over [l, r]
//   2 l r : +1 on tree b over [l + n, r + n]
//   3 ... : answer a rectangle query via solve()
// Any other op code is ignored.
void work()
{
    scanf("%d%d", &n, &m);
    build(1, 1, 2 * n, a);
    build(1, 1, 2 * n, b);
    while(m--) {
        int op;
        scanf("%d", &op);
        switch(op) {
        case 1:
        case 2: {
            int l, r;
            scanf("%d%d", &l, &r);
            if(op == 1)
                update(1, 1, 2 * n, a, l, r);
            else
                update(1, 1, 2 * n, b, l + n, r + n);
            break;
        }
        case 3:
            solve();
            break;
        default:
            break;
        }
    }
}
// Read the number of test cases, then print a "Case #k:" header before each one.
int main()
{
    int cases;
    scanf("%d", &cases);
    for(int tc = 1; tc <= cases; ++tc) {
        printf("Case #%d:\n", tc);
        work();
    }
    return 0;
}