【题解】[WC2016]论战捆竹竿
题目链接:[WC2016]论战捆竹竿
题意:
给定长度为 $n$ 的字符串 $S$,现有一个空串 $T$,每次可将 $S$ 去掉一个 $\text{border}$ 后接在 $T$ 上。问 $T$ 的长度可以是 $[n,w]$ 中的多少个数。
$1\le n\le 5\times 10^5,1\le w\le 10^{18}$。
每次去掉一个 $\text{border}$ 加入,等价于加入一个 $\text{period}$。设 $S$ 的 $\text{period}$ 长度构成集合 $X$,取出 $X$ 中一个元素 $x$。
借用同余最短路的思想,设 $\text{mindis}_i$ 表示 $\bmod\ x=i$ 的所有长度中,能被 $X$ 中元素所拼出的最小长度。则能够被拼出的,长度 $\bmod\ x=i$ 的串长为 $\text{mindis}_i+k\times x$。可以通过该数组快速求出答案。
则现在需要做的是加速同余最短路的过程,快速求出 $\text{mindis}_i$。
考虑以下结论:
一个字符串 $S$ 的 $\text{border}$ 长度构成 $\mathcal{O}(\log n)$ 个等差数列。
而 $\text{period}$ 对应 $\text{border}$,所以也是等差数列。这启发我们将一个等差数列中的 $\text{border}$ 放在一起考虑。
对于一个等差数列 $a,a+b,a+2b\cdots a+kb$,若当前已经得到之前的等差数列所求出的,$\bmod\ a$ 意义下的 $\text{mindis}_i$,考虑如何快速用一个等差数列更新该数组。
在 $\bmod\ a$ 意义下,公差为 $b$ 的等差数列将所有 $[0,a-1]$ 中的元素划分为 $\gcd(a,b)$ 个等价类,每个等价类互不干扰。
对于一个等价类,找出其中 $\text{mindis}$ 值最小的位置 $p$,显然在这轮更新中,$\text{mindis}_p$ 不改变。则从 $p$ 开始,每次考虑使用等差数列中一个元素来更新后面的元素。
具体来说,对于一个下标为 $q$ 的位置,若 $q-p\le k$,则可以使用 $\text{mindis}_p+(q-p)\times b$ 来更新 $\text{mindis}_q$。(此处将在 $\bmod\ a$ 意义下的环展开,并认为 $p$ 是其中第一个元素,则 $q>p$)
这显然可以使用单调队列来实现。
还剩下一个问题,求解完当前等差数列后,如何跳转到下一个等差数列?对于一个首项为 $a_1$ 的等差数列,求解完后 $\text{mindis}$ 是在 $\bmod\ a_1$ 的意义下,而下一个首项为 $a_2$ 的等差数列中需在 $\bmod\ a_2$ 意义下进行。
首先,可以使用 $\text{mindis}_x$ 转移到 $\text{mindis'}_{\text{mindis}_x\bmod a_2}$,同时,还需要考虑每一个长度为 $a_1$ 的转移,即 $\text{mindis}_x+k\times a_1$ 也可能更新 $\text{mindis'}$。这和上面是类似的问题,但是由于没有项数限制,所以甚至不用单调队列,记一个前缀 $\min$ 即可。
时间复杂度 $\mathcal{O}(T\times n\log n)$。
//Code By CXY07
#include<bits/stdc++.h>
using namespace std;
//#define FILE
#define int long long
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define LINE() cout << "LINE = " << __LINE__ << endl
#define debug(x) cout << #x << " = " << x << endl
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define min(a, b) (a < b ? a : b)
#define inv(x) qpow((x), mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define ull unsigned long long
#define pii pair<int, int>
#define LL long long
#define mp make_pair
#define pb push_back
#define scd second
#define vec vector
#define fst first
#define endl '\n'
const int MAXN = 5e5 + 10;
const int INF = 1.1e18;
const double PI = acos(-1);
const double eps = 1e-6;
//const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;
int T, n, w, mod, Ans;
int fail[MAXN], period[MAXN], cnt;
int dis[MAXN];
char s[MAXN];
template<typename T> inline bool read(T &a) {
a = 0; char c = getchar(); int f = 1;
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
a *= f;
return 1;
}
template<typename A, typename ...B>
inline bool read(A &x, B &...y) {return read(x) && read(y...);}
int Gcd(int x, int y) {
if(!y) return x;
return Gcd(y, x % y);
}
void clear() {
memset(dis, 0x3f, sizeof dis);
memset(fail, 0, sizeof fail);
}
void GetPeriod() {
fail[1] = 0, fail[0] = -1, cnt = 0;
for(int i = 2, j; i <= n; ++i) {
j = fail[i - 1];
while((~j) && s[j + 1] != s[i]) j = fail[j];
fail[i] = ++j;
}
for(int i = fail[n]; ~i; i = fail[i]) period[++cnt] = n - i;
}
void _run(int p, int gap, int c) { // i % mod == p, mod...mod + gap * c
static int t[MAXN], m, pos; m = 0, pos = -1;
static int q[MAXN], head, tail; head = 1, tail = 0;
for(int i = p; ;) {
if(pos == -1 || dis[i] < dis[pos]) pos = i;
i = (i + gap) % mod; if(i == p) break;
}
for(int i = pos; ;) {
t[++m] = i; i = (i + gap) % mod;
if(i == pos) break;
} assert(m == mod / Gcd(gap, mod));
q[++tail] = 1;
for(int i = 2; i <= m; ++i) {
while(head <= tail && i - q[head] > c) head++;
if(head <= tail && i - q[head] <= c) dis[t[i]] = min(dis[t[i]], dis[t[q[head]]] + (i - q[head]) * gap + mod);
while(head <= tail && dis[t[i]] - i * gap < dis[t[q[tail]]] - q[tail] * gap) tail--;
q[++tail] = i;
}
}
void run(int gap, int c) { // mod...mod + gap * c
int lim = Gcd(gap, mod);
for(int i = 0; i < lim; ++i) _run(i, gap, c);
}
void trans(int _new) { // mod -> _new
static int _d[MAXN]; memset(_d, 0x3f, sizeof _d);
static int t[MAXN], m;
int lim = Gcd(mod, _new);
for(int i = 0; i < mod; ++i) _d[dis[i] % _new] = min(_d[dis[i] % _new], dis[i]);
for(int p = 0, pos, mn; p < lim; ++p) {
pos = -1, m = 0, mn = INF;
for(int i = p; ;) {
if(pos == -1 || _d[i] < _d[pos]) pos = i;
i = (i + mod) % _new; if(i == p) break;
}
for(int i = pos; ;) {
t[++m] = i;
i = (i + mod) % _new; if(i == pos) break;
} assert(m == _new / lim);
for(int i = 1; i <= m; ++i) {
_d[t[i]] = min(_d[t[i]], mn + i * mod);
mn = min(mn, _d[t[i]] - i * mod);
}
}
mod = _new;
memcpy(dis, _d, sizeof dis);
}
void calc() {
int L = 1, R; mod = period[L];
while(L < cnt) {
R = L + 1;
while(R + 1 <= cnt && period[R + 1] - period[R] == period[R] - period[R - 1]) R++;
run(period[L + 1] - period[L], R - L);
L = R + 1; if(L <= cnt) trans(period[L]);
}
}
void solve() {
Ans = 0;
clear(); read(n), read(w), scanf("%s", s + 1); w -= n;
GetPeriod(); dis[0] = 0; calc();
for(int i = 0; i < mod; ++i)
if(dis[i] <= w) Ans = Ans + (w - dis[i]) / mod + 1;
printf("%lld\n", Ans);
}
signed main () {
#ifdef FILE
freopen("P4156.in", "r", stdin);
freopen("P4156.out", "w", stdout);
#endif
read(T);
while(T--) solve();
return 0;
}