题目链接

https://atcoder.jp/contests/abc194/tasks/abc194_f

题意

给定一个数 $1\leq N \leq 16^{2\times10^5}$,求:

16进制 下,满足以下条件的整数 $x$ 数量:

  1. 拥有 Exactly $K$ 个不同的digit (例如 $x = 1F21$,就有3个不同的digit)
  2. $x \in [1,N]$
题解

我们使用 dp,为了方便处理 $x \leq N$ 的问题,我们从最高位(Most significant digit)开始处理。

我们设 dp[i][j] 为,我们处理到了第 i 位,使用了 j 个不同的digit,且满足以下条件的数字 $x$ 的数量:

  1. $x$ 严格小于 $N$ 的前 i
  2. $x \neq 0$

注意,$x$ 不一定完全有 i 个digit,但是它至少有一个有效的digit。

由上,因为第一个条件,我们无论在这一位选择什么digit,都必然会仍然满足条件 $1$ 中的 严格小于 $N$。

所以我们可以得到第一个转移方程:

for (ll j = 2; j <= k; j++) {
    dp[i][j] = (dp[i][j] + dp[i-1][j-1] * (16-j+1)) % mod;
    dp[i][j] = (dp[i][j] + dp[i-1][j] * j) % mod;
}

我们发现,这样只考虑到了前 i-1 位至少有一个有效digit的情况,没有考虑全部为前缀0的情况。

如果前 i-1 位全是前缀0,那么无论在这一位选择任何数,都必然满足条件 $1$ 中的 严格小于 $N$。

所以我们可以得到第二个转移方程:

dp[i][1] = (dp[i-1][1] + 15) % mod;

注意到,我们的 dp 数组里,并没有包含 $x$ 的前 i-1完全等同于 $N$ 的前 i-1 位的情况。我们需要单独处理它!

我们在当前这一位(第i位),能够计入 dp 数组的只有 i 位小于 $N$ 的第 i的情况。

并且,我们可以很容易得到 $N$ 的前 i-1 位有多少个不同的digit(这样就得到了 dp 数组里的 j)。

所以,我们只要枚举一下,第 i 位可以选择的所有digit,计算一下不同的digit数量,然后计入 dp 数组即可。

所以我们得到了第三个转移方程:

map<char, int> dict;   // 用于将 0123456789ABCDEF map到对应的int
set<int> used;  // 记录N的前i-1位用了哪些数字

char c = s[i-1];  // N的第i位数字
int cur = dict[c];
int pre = used.size();
for (int j = 0; j < cur; j++) {  //枚举所有可以选择的digit
    if (!used.count(j)) {   // 计算不同digit的数量
        dp[i][pre+1]++;
    } else dp[i][pre]++;
}
used.insert(cur);

最后别忘记,看一下 $N$ 自己是否也满足条件(拥有Exactly $K$ 个不同的digit)。

小结:

本题是一个非常不错的dp,主要用了以下几个关键的trick:

  1. 对于 Exactly $K$ 不同的digit,我们不关心具体是哪几个digit,只需知道不同的digit数量就可以计数了。
  2. 通过限制 dp 数组的定义,让 dp 数组仅记录 严格小于 $N$ 的数字,方便计数。
  3. 单独处理 $x$ 与 $N$ 的前 i 位相同的情况,并且将符合条件的计数加到 dp 数组当中去。

需要重点关注的是这个 DP并不记录最终答案,而是记录 满足某种条件的部分答案其余特殊情况单独处理的思想。

代码
using namespace std;
#include <bits/stdc++.h>

#define ll long long

const int mod = 1e9+7;
const int maxn = 2e5+5;

ll dp[maxn][18];
string s;
int k;

map<char, int> dict;
set<int> used;

int main() {
    fastio;

    for (int i = 0; i <= 9; i++) dict[(char)(i+'0')] = i;
    int o = 10;
    for (char c = 'A'; c <= 'F'; c++) dict[c] = o++;

    cin >> s >> k;
    int n = s.size();
    dp[1][1] = dict[s[0]] - 1;
    used.insert(dict[s[0]]);

    for (ll i = 2; i <= n; i++) {
        char c = s[i-1];
        dp[i][1] = (dp[i-1][1] + 15) % mod;
        int cur = dict[c];

        for (ll j = 2; j <= k; j++) {
            dp[i][j] = (dp[i][j] + dp[i-1][j-1] * (16-j+1)) % mod;
            dp[i][j] = (dp[i][j] + dp[i-1][j] * j) % mod;
        }

        int pre = used.size();
        for (int j = 0; j < cur; j++) {
            if (!used.count(j)) {
                dp[i][pre+1]++;
            } else dp[i][pre]++;
        }

        used.insert(cur);
    }
    if (used.size() == k) dp[n][k]++;
    ll ans = dp[n][k] % mod;
    cout << ans << endl;
}