Luogu P5664 Emiya家今天的饭(计数,dp)
Contents
题目链接
https://www.luogu.com.cn/problem/P5664
题意
有 $n$ 种烹饪方法,$m$ 种主要食材。每道菜都只用 恰好一种 烹饪方法和主要食材,同时对于 每种烹饪方法 $i$ 和 主要食材 $j$,有 $a_{ij}$ 种不同的菜。所以总共有 $\sum\limits_{i=1}^n \sum\limits_{j=1}^m a_{ij}$ 道不同的菜。
现在需要求做菜方案 (设总共有 $k$ 道菜,$k$ 可以取任何数),满足:
- $k \geq 1$
- 每道菜的烹饪方法 $i$ 各不相同
- 每种主要食材 $j$ 最多在一半($\lfloor \frac{k}{2} \rfloor$)的菜中出现
求满足上述条件的做菜方案个数?
$1 \leq n \leq 100, 1 \leq m \leq 2000, 0 \leq a_{ij} < 998244353$
题解
首先考虑条件1和2,满足这些条件的总方案数有 $((s_1+1) * (s_2+1) * (s_3+1) * … * (s_n+1) - 1)$ 种,其中 $s_i = \sum\limits_{j=1}^m a_{ij}$
(因为对于每种烹饪方法 $i$,还可以 不选,所以是 $(s_i+1)$,最后减去 全部不选 的情况)
这样,我们减去 不满足条件3 的方案数即可!
我们枚举超过限制的主要食材 $j$,然后设 $dp[i][k]$ 为:当前到了第 $i$ 种烹饪方法,使用了 $k$ 种主要食材 $j$ 的方案数。
那么问题关键在于,对于某一种烹饪方法 $i$,我们可以不选任何菜,这怎么办?
我们假设有 $t$ 个不选的,那么总共就选了 $n-t$ 个菜,要保证 $k > \lfloor \frac{n-t}{2} \rfloor$,即 $2k + t > n$。
所以,我们可以改变一下状态的定义,我们可以将 不选 变成 选了一种主要食材 $j$,而 选择主要食材 $j$ 就变成 选择了两个主要食材 $j$。这样,只要满足 2 * 选择主要食材 + 不选的数量 = 2k + t > n
,就不满足条件3了!
转移方程就很好写了,对于每一种烹饪方法 $i$,有 $3$ 种决策方案:
- 选择主要食材 $j$:
dp[i][k] += dp[i-1][k-2] * a[i][j];
- 不选任何菜:
dp[i][k] += dp[i-1][k-1];
- 选择非主要食材:
dp[i][k] += (dp[i-1][k] * (sum[i] - a[i][j]));
注意,不能将
dp
数组变成一维然后倒序转移!比如,在第三种转移时,
(sum[i] - a[i][j]) == 0
,那么如果是二维的,此时dp[i][k] = 0
,而如果是一维的,就会变成dp[i][k] = dp[i-1][k]
。
代码
#include <bits/stdc++.h>
#define ll long long
const int mod = 998244353;
const int maxn = 1e5+5;
ll dp[103][205];
ll sum[105];
ll a[103][2003];
int n,m;
int main() {
fastio;
ll ans = 1;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
cin >> a[i][j];
sum[i] = (sum[i] + a[i][j]) % mod;
}
ans = (ans * (sum[i]+1LL)) % mod;
}
ll delta = 1LL;
for (int j = 1; j <= m; j++) {
memset(dp, 0, sizeof(dp));
dp[0][0] = 1;
for (int i = 1; i <= n; i++) {
for (int k = 0; k <= 2*n; k++) {
dp[i][k] = (dp[i][k] + dp[i-1][k] * (sum[i] - a[i][j]) % mod) % mod;
if (k >= 1) {
(dp[i][k] += dp[i-1][k-1]) %= mod;
}
if (k >= 2) {
dp[i][k] = (dp[i][k] + dp[i-1][k-2] * a[i][j] % mod) % mod;
}
}
}
for (int k = n+1; k <= 2*n; k++) (delta += dp[n][k]) %= mod;
}
ans -= delta;
(ans %= mod) += mod;
ans %= mod;
cout << ans << endl;
}