【题解】[WC2021] 表达式求值
趣味妙题
题目链接:[WC2021] 表达式求值
题意:
给定 $m$ 个长度为 $n$ 的数组,现在有两种二元操作,表示两个数组每一位取 $\max$ 或 $\min$。
现给定一个包括该两个二元操作与 "$?$" 的表达式 $E$,求将 "$?$" 任意填写后,得到的所有数组的元素的和对 $10^9+7$ 取模。
$1\le n\le 5\times 10^4,1\le |E|\le 5\times 10^4,\ 1\le m\le 10$。
这题后半部分实际上是我高一军训期间无聊,出给机房同学的简单题,考场上瞬间联想到这个,然后被自己给震惊到。赞赞
考场只写了 $\text{70pts}$,但是实际上做法和正解只差一点了。 真的亏
首先肯定不能直接对着表达式做,所以建出表达式树。那么现在叶子上有若干值,每一个非叶子节点有一个 $\max,\min$ 或 "$?$" 的标记,表示对两个儿子做该操作。
发现每一位是独立的,因为二元操作都是按位操作的,所以枚举每一位,问题转化为叶子上是一个数,而非数组。
考虑一个错误的想法,如果考场上你以为这个答案可以利用单调性,然后二分一个数 $x$,那么就可以套路地将 $\ge x$ 的数写作 $1$,$<x$ 的部分写作 $0$,这样一个 $\max,\min$ 就变成了按位或、按位与。
但你发现这样会算重,因为二分的时候你不知道要给答案加入多少,所以不能二分。
考虑到 $m\le 10$,那就枚举好了。直接将某一位的 $m$ 个值从小到大排序,依次按照这 $m$ 个值进行划分。值的总和只需要进行简单容斥就可以轻松求出。
至于上述的计数 $\text{dp}$,设 $\text{dp}_{x,0/1}$ 为当前在 $x$ 点,数是 $0/1$ 的方案数,随便转移一下,最后答案就是 $\text{dp}_{\text{root},1}$,单次 $\mathcal{O}(|E|)$。
这样的复杂度是 $\mathcal{O}(nm|E|)$,加上没有”$?$“ 的暴力就可以拿到 $\text{70pts}$。由于本人当时还对 $\texttt{T1}$ 毫无思路,所以就直接走了,现在回想起来还是有些可惜的。
如何优化复杂度呢?每一位单独枚举是不太能优化的了,所以考虑从 $m$ 入手。
$m\le 10$,为什么不直接 $2^m$ 暴力枚举在某一种划分的方案下,每一个值是 $0/1$ 的方案数呢?预处理下来,之后每次查询就只需要 $\mathcal{O}(1)$ 了。
于是就可以通过本题,时间复杂度 $\mathcal{O}(2^m|E|+nm\log m)$。
$\text{Code:}$
//Code By CXY07
#include<bits/stdc++.h>
using namespace std;
//#define FILE
//#define int long long
#define debug(x) cout << #x << " = " << x << endl
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define LINE() cout << "LINE = " << __LINE__ << endl
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define vec vector
const int MAXN = 5e4 + 10;
const int MAXM = 11;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;
int n, m, cnt, top;
int a[MAXM][MAXN], opt[MAXN], st[MAXN];
int ls[MAXN], rs[MAXN];
int id[MAXM], p2[MAXM], sav[1 << MAXM];
LL dp[MAXN][2], Ans;
char stk[MAXN], s[MAXN], og[MAXN];
pii now[MAXM];
template<typename T> inline bool read(T &a) {
a = 0; char c = getchar(); int f = 1;
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
a *= f;
return 1;
}
void Modadd(LL &x, LL b) {
x += b;
if(x >= mod) x -= mod;
}
void build() {
scanf("%s", og + 1);
int len = strlen(og + 1), pos = 0;
for(int i = 1; i <= len; ++i) {
if('0' <= og[i] && og[i] <= '9') s[++pos] = og[i];
else {
if(og[i] == '<' || og[i] == '>' || og[i] == '?') {
while(top && stk[top] != '(' && stk[top] != ')') s[++pos] = stk[top--];
stk[++top] = og[i];
} else {
if(og[i] == '(') stk[++top] = '(';
else {
while(top && stk[top] != '(') s[++pos] = stk[top--];
top--;
}
}
}
}
while(top) if(stk[top] != '(' && stk[top] != ')') s[++pos] = stk[top--];
top = 0;
for(int i = 1; i <= pos; ++i) {
cnt++;
if('0' <= s[i] && s[i] <= '9') opt[cnt] = s[i] - '0' + 1, st[++top] = cnt;
else {
if(s[i] == '?') opt[cnt] = 11;
else if(s[i] == '<') opt[cnt] = 12;
else opt[cnt] = 13;
ls[cnt] = st[top - 1], rs[cnt] = st[top];
top = top - 2; st[++top] = cnt;
}
}
}
void calc(int x) {
dp[x][0] = dp[x][1] = 0;
if(opt[x] <= m) dp[x][id[opt[x]]] = 1;
int u = ls[x], v = rs[x];
if(!u || !v) return;
calc(u), calc(v);
if(opt[x] == 11 || opt[x] == 12) {
Modadd(dp[x][0], dp[u][0] * dp[v][0] % mod);
Modadd(dp[x][0], dp[u][0] * dp[v][1] % mod);
Modadd(dp[x][0], dp[u][1] * dp[v][0] % mod);
Modadd(dp[x][1], dp[u][1] * dp[v][1] % mod);
}
if(opt[x] == 11 || opt[x] == 13) {
Modadd(dp[x][0], dp[u][0] * dp[v][0] % mod);
Modadd(dp[x][1], dp[u][0] * dp[v][1] % mod);
Modadd(dp[x][1], dp[u][1] * dp[v][0] % mod);
Modadd(dp[x][1], dp[u][1] * dp[v][1] % mod);
}
}
inline int DP() {calc(cnt); return dp[cnt][1];}
inline void solve() {
for(int S = 0; S < p2[m]; ++S) {
for(int k = 0; k < m; ++k) id[k + 1] = (S & p2[k]) ? 1 : 0;
sav[S] = DP();
}
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j) now[j] = mp(a[j][i], j);
sort(now + 1, now + m + 1);
now[0] = mp(0, 0);
int S = p2[m] - 1, pos = 1;
for(int j = 1; j <= m; ++j) {
while(pos < j && now[pos].fst < now[j].fst) S ^= p2[now[pos++].scd - 1];
Modadd(Ans, 1ll * (now[j].fst - now[j - 1].fst) * sav[S] % mod);
}
}
printf("%lld\n", Ans);
}
signed main () {
#ifdef FILE
freopen("expr.in","r",stdin);
freopen("expr.out","w",stdout);
#endif
read(n), read(m);
p2[0] = 1;
for(int i = 1; i <= m; ++i) p2[i] = p2[i - 1] << 1;
for(int i = 1; i <= m; ++i)
for(int j = 1; j <= n; ++j) read(a[i][j]);
build(); solve();
return 0;
}