Atcoder ABC 194F(数位DP,进制处理)
Contents
题目链接
https://atcoder.jp/contests/abc194/tasks/abc194_f
题意
给定一个数 $1\leq N \leq 16^{2\times10^5}$,求:
在 16进制 下,满足以下条件的整数 $x$ 数量:
- 拥有 Exactly $K$ 个不同的digit (例如 $x = 1F21$,就有3个不同的digit)
- $x \in [1,N]$
题解
我们使用 dp
,为了方便处理 $x \leq N$ 的问题,我们从最高位(Most significant digit)开始处理。
我们设 dp[i][j]
为,我们处理到了第 i
位,使用了 j
个不同的digit,且满足以下条件的数字 $x$ 的数量:
- $x$ 严格小于 $N$ 的前
i
位 - $x \neq 0$
注意,$x$ 不一定完全有
i
个digit,但是它至少有一个有效的digit。
由上,因为第一个条件,我们无论在这一位选择什么digit,都必然会仍然满足条件 $1$ 中的 严格小于 $N$。
所以我们可以得到第一个转移方程:
for (ll j = 2; j <= k; j++) {
dp[i][j] = (dp[i][j] + dp[i-1][j-1] * (16-j+1)) % mod;
dp[i][j] = (dp[i][j] + dp[i-1][j] * j) % mod;
}
我们发现,这样只考虑到了前 i-1
位至少有一个有效digit的情况,没有考虑全部为前缀0的情况。
如果前 i-1
位全是前缀0,那么无论在这一位选择任何数,都必然满足条件 $1$ 中的 严格小于 $N$。
所以我们可以得到第二个转移方程:
dp[i][1] = (dp[i-1][1] + 15) % mod;
注意到,我们的 dp
数组里,并没有包含 $x$ 的前 i-1
位 完全等同于 $N$ 的前 i-1
位的情况。我们需要单独处理它!
我们在当前这一位(第i
位),能够计入 dp
数组的只有 第 i
位小于 $N$ 的第 i
位的情况。
并且,我们可以很容易得到 $N$ 的前 i-1
位有多少个不同的digit(这样就得到了 dp
数组里的 j
)。
所以,我们只要枚举一下,第 i
位可以选择的所有digit,计算一下不同的digit数量,然后计入 dp
数组即可。
所以我们得到了第三个转移方程:
map<char, int> dict; // 用于将 0123456789ABCDEF map到对应的int
set<int> used; // 记录N的前i-1位用了哪些数字
char c = s[i-1]; // N的第i位数字
int cur = dict[c];
int pre = used.size();
for (int j = 0; j < cur; j++) { //枚举所有可以选择的digit
if (!used.count(j)) { // 计算不同digit的数量
dp[i][pre+1]++;
} else dp[i][pre]++;
}
used.insert(cur);
最后别忘记,看一下 $N$ 自己是否也满足条件(拥有Exactly $K$ 个不同的digit)。
小结:
本题是一个非常不错的dp,主要用了以下几个关键的trick:
- 对于 Exactly $K$ 不同的digit,我们不关心具体是哪几个digit,只需知道不同的digit数量就可以计数了。
- 通过限制
dp
数组的定义,让dp
数组仅记录 严格小于 $N$ 的数字,方便计数。- 单独处理 $x$ 与 $N$ 的前
i
位相同的情况,并且将符合条件的计数加到dp
数组当中去。需要重点关注的是这个 DP并不记录最终答案,而是记录 满足某种条件的部分答案,其余特殊情况单独处理的思想。
代码
using namespace std;
#include <bits/stdc++.h>
#define ll long long
const int mod = 1e9+7;
const int maxn = 2e5+5;
ll dp[maxn][18];
string s;
int k;
map<char, int> dict;
set<int> used;
int main() {
fastio;
for (int i = 0; i <= 9; i++) dict[(char)(i+'0')] = i;
int o = 10;
for (char c = 'A'; c <= 'F'; c++) dict[c] = o++;
cin >> s >> k;
int n = s.size();
dp[1][1] = dict[s[0]] - 1;
used.insert(dict[s[0]]);
for (ll i = 2; i <= n; i++) {
char c = s[i-1];
dp[i][1] = (dp[i-1][1] + 15) % mod;
int cur = dict[c];
for (ll j = 2; j <= k; j++) {
dp[i][j] = (dp[i][j] + dp[i-1][j-1] * (16-j+1)) % mod;
dp[i][j] = (dp[i][j] + dp[i-1][j] * j) % mod;
}
int pre = used.size();
for (int j = 0; j < cur; j++) {
if (!used.count(j)) {
dp[i][pre+1]++;
} else dp[i][pre]++;
}
used.insert(cur);
}
if (used.size() == k) dp[n][k]++;
ll ans = dp[n][k] % mod;
cout << ans << endl;
}