题解:P4396 [AHOI2013] 作业

P4396 [AHOI2013] 作业 - 洛谷

这就是一个莫队题, 但是在分块题单内.

问题 1: 区间 $[l,r]$ 内, 值属于 $[a,b]$ 的元素个数

首先有一个分块想法, 把元素按照块放在 vector 里, 给每个块排序. 查询时整块 upper_bound(ve[i].begin(), ve[i].end(), b) - lower_bound(ve[i].begin(), ve[i].end(), a), 散块暴力枚举. 但这样太慢了.

设 $sum_{i,v}$ 表示第 $i$ 个块中值 $\le v$ 的数量. 那么这个就可以看成是一个二维前缀和, 可以 $O(1)$ 查询整块答案.两端散块再暴力扫描即可.

预处理复杂度 $O(nV)$, 其中 $V$ 为值域($10^5$), 查询变成 $O(\sqrt n)$.

问题 2: 区间 $[l,r]$ 内, 值属于 $[a,b]$ 的不同元素个数

我们维护两个核心预处理数组:

  1. short f[i][j][k]: 表示序列第 $i$ 块到第 $j$ 块中, 值域属于第 $k$ 个值域块的不同元素个数.
    预处理时, 固定起点块 $i$, 向后扫描序列, 用 vis 去重即可. 由于答案不会超过块长, 因此可以使用 short 节省空间.

  2. int fd[i][v]: 表示数值 $v$ 在序列第 $i$ 块及之后第一次出现的位置.
    倒序扫描原序列即可递推得到.

查询时, 若 $(l,r)$ 在同一块内, 直接暴力扫描并去重. 否则设中间完整序列块为 $[L+1,R-1]$, 答案分为三部分.

Part A (中间序列块 + 完整值域块)

对于值域中间的完整块, 直接利用预处理:

1
cnt += f[L+1][R-1][k];

Part C (序列散块)

遍历序列左右散块, 利用 vis 去重.

对于每个满足 $a \le v \le b$ 的值:

  • 若属于值域散块, 仅标记 vis, 留待 Part B 统计.
  • 若属于完整值域块, 且 $fd[L + 1][v] > (R - 1) \times S$, 说明它没有在中间完整序列块中出现, 应贡献答案.

Part B (值域散块)

最后枚举值域左右两个散块中的每个值.

  • vis[v] 为真, 说明它已在序列散块中出现.
  • 否则判断 $fd[L + 1][v] \le (R - 1) \times S$, 若成立, 则说明它在中间完整序列块中出现.

一个值满足以上两种情况之一, 即可贡献一种不同元素.

参考代码

需要略微卡下常, 这里不过多说.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <algorithm>
#include <cstring>
#include <iostream>
#include <set>
#include <utility>
#include <vector>
using namespace std;
const int N = 1e5 + 5, S = 316;
int arr[N], bl[N];
short f[S + 30][S + 30][S + 30];
int lst[N], fd[S + 30][N];
int sum[S + 30][N];

bool vis[N];

int ch[N];
int top = 0;
pair<int, int> solve(int l, int r, int a, int b) {
int num = 0, cnt = 0;
if (bl[l] == bl[r]) {
for (int i = l; i <= r; ++i) {
if (a <= arr[i] && arr[i] <= b) {
num++;
vis[arr[i]] = 1;
ch[++top] = arr[i];
}
}
for (int i = 1; i <= top; i++) {
if (vis[ch[i]]) cnt++;
vis[ch[i]] = 0;
}
top = 0;
return make_pair(num, cnt);
}
if (bl[l] + 1 <= bl[r] - 1) {
num += (sum[bl[r] - 1][b] - sum[bl[l]][b]) - (sum[bl[r] - 1][a - 1] - sum[bl[l]][a - 1]);

for (int i = bl[a] + 1; i <= bl[b] - 1; ++i) cnt += f[bl[l] + 1][bl[r] - 1][i];
}

for (int i = l; bl[l] == bl[i]; ++i) {
int v = arr[i];
if (a <= v && v <= b) {
num++;
if (!vis[v]) {
vis[v] = 1;
ch[++top] = v;
if (bl[a] < bl[v] && bl[v] < bl[b]) {
if (bl[r] - bl[l] <= 1 || fd[bl[l] + 1][v] > (bl[r] - 1) * S) {
cnt++;
}
}
}
}
}
for (int i = r; bl[r] == bl[i]; --i) {
int v = arr[i];
if (a <= v && v <= b) {
num++;
if (!vis[v]) {
vis[v] = 1;
ch[++top] = v;
if (bl[a] < bl[v] && bl[v] < bl[b]) {
if (bl[r] - bl[l] <= 1 || fd[bl[l] + 1][v] > (bl[r] - 1) * S) {
cnt++;
}
}
}
}
}
for (int v = a; v <= b && bl[v] == bl[a]; ++v) {
if (vis[v]) {
cnt++;
} else if (bl[r] - bl[l] >= 2 && fd[bl[l] + 1][v] <= (bl[r] - 1) * S) {
cnt++;
}
}
if (bl[a] != bl[b]) {
for (int v = b; v >= a && bl[v] == bl[b]; --v) {
if (vis[v]) {
cnt++;
} else if (bl[r] - bl[l] >= 2 && fd[bl[l] + 1][v] <= (bl[r] - 1) * S) {
cnt++;
}
}
}
for (int i = 1; i <= top; i++) {
vis[ch[i]] = 0;
}
top = 0;
return make_pair(num, cnt);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= N - 5; ++i) {
bl[i] = (i - 1) / S + 1;
}
for (int i = 1; i <= n; i++) {
cin >> arr[i];
sum[bl[i]][arr[i]]++;
}
for (int i = 1; i <= bl[n]; i++)
for (int v = 1; v <= N - 5; v++) sum[i][v] += sum[i][v - 1];
for (int i = 1; i <= bl[n]; i++)
for (int v = 1; v <= N - 1; v++) sum[i][v] += sum[i - 1][v];
for (int i = 1; i <= bl[n]; ++i) {
// memset(vis, 0, sizeof(vis));
for (int j = (i - 1) * S + 1; j <= n; j++) {
if (!vis[arr[j]]) {
f[i][bl[j]][bl[arr[j]]]++;
vis[arr[j]] = 1;
ch[++top] = arr[j];
}
}
for (int j = i + 1; j <= bl[n]; j++) {
for (int k = 1; k <= bl[N - 5]; k++) {
f[i][j][k] += f[i][j - 1][k];
}
}
for (int i = 1; i <= top; i++) vis[ch[i]] = 0;
top = 0;
}
// memset(vis, 0, sizeof(vis));

// memset(lst, 0x3f, sizeof(lst));
for (int i = 1; i < N; i++) {
lst[i] = 0x3f3f3f3f;
}
for (int i = n; i >= 1; --i) {
lst[arr[i]] = i;
if (i % S == 1) {
for (int j = 1; j <= N - 5; j++) {
fd[bl[i]][j] = lst[j];
}
}
}

while (m--) {
int l, r, a, b;
cin >> l >> r >> a >> b;
auto res = solve(l, r, a, b);
cout << res.first << " " << res.second << "\n";
}
return 0;
}