CF 1395E(集合哈希)
Contents
题目链接
https://codeforces.com/contest/1395/problem/E
题意
给定一个 directed and weighted graph,$2 \leq n \leq 2\cdot10^5, 2 \leq m \leq \min(2\cdot10^5, n(n-1))$,每个vertex的 out-degree 最多为 $1\leq k \leq 9$,每个edge的weight均不相同。
现在定义一个tuple $(c_1,c_2,…,c_k)$,其中, $\forall j \in [1,k]$,有 $1\leq c_j \leq j$。
且对于所有 out-degree 等于 $j$ 的vertex,只保留它的out-going edges中,weight第 $c_j$ 小的那个edge。
例如,对于vertex 5,有 $3$ 条out-going edges (以(u,v,w)的形式): $(5,2,233), (5,4,25), (5,6,999)$
如果 $c_3 = 1$,因为 vertex 5 具有 out-degree = $3$,所以 $c_3$ 生效,因为 $c_3 = 1$,所以保留第 $1$ 小的edge,也就是$(5,4,25)$。
现在求 $(c_1,c_2,…,c_k)$ 的数量,使得整个图是强连通的。
题解
首先可以发现,因为每个edge的weight都不同,所以对于每一个vertex来说,必然只能保留最多 $1$ 个out-going edge。所以整个图中,只有 $n$ 个edge。
又因为需要强连通,所以它只有可能是 一个环!所以只要看 每个vertex的 in-degree 是否都等于$1$即可!
再转化一下,我们只要看在 $(c_1,c_2,…,c_k)$ 的情况下,所有out-going edge所指向的vertex,并起来,形成的 可重复集合(multiset) 是否为 $\{1,2,3,…,n\}$ 即可!
我们可以预处理出 对于每一个 $j$,如果 $c_j = x$,所指向的vertex组成的集合。然后在枚举 $(c_1,c_2,…,c_k)$ 的时候,判断一下这些集合的并集是否为 $\{1,2,3,…,n\}$ 即可!
那么,如何快速的
- 判断集合是否相等 和
- 求可重复并集 呢?
使用Hashing!
我们要定义一种Hash函数,使得上述两个操作的速度为 $O(1)$ 。
在字符串哈希中,我们用字符的位置来hash,但是因为我们不关心集合中元素的顺序,所以可以用集合中元素的value来hash!
假设给定一个集合 $\{a_1,a_2,…,a_m\}$,定义哈希值为:$\sum\limits_{i=1}^mp^{a_i} = p^{a_1} + p^{a_2} + … + p^{a_m}$,这样:
- 判断集合是否相等:直接比较两个集合的哈希值
- 求两个集合的可重复并集:直接将两个集合的哈希值相加
代码
using namespace std;
#include <bits/stdc++.h>
const int mod = 1e9+7;
const int maxn = 2e5+5;
const ll p = 31;
int n,m,k;
vector<pii> adj[maxn]; // {w, to}
int out[maxn];
vector<int> deg[10]; // store vertices with deg x
ll pow31[maxn];
ll ha[10][10]; //ha[i][j] 代表 c_i=j时,所指向的vertex的并集的哈希值
ll tar = 0; // {1,2,3...,n}对应的哈希值
int ans = 0;
void init() {
pow31[0] = 1;
for (ll i = 1; i <= n; i++) pow31[i] = (pow31[i-1] * p) % mod;
for (ll i = 1; i <= n; i++) tar += pow31[i];
tar %= mod;
for (int i = 1; i <= k; i++) { // calculate all vertex with deg i
for (int j = 1; j <= i; j++) { // if c_i = j
for (int a : deg[i]) {
int to = adj[a][j-1].second;
(ha[i][j] += pow31[to]) %= mod;
}
}
}
}
void dfs(int dep, ll cur) { //cur: current hash value
if (dep == k+1) {
if (tar == cur) ans++;
return;
}
for (int j = 1; j <= dep; j++) {
dfs(dep+1, (cur + ha[dep][j]) % mod);
}
}
int main() {
cin >> n >> m >> k;
for (int i = 1; i <= m; i++) {
int u,v,w; cin >> u >> v >> w;
adj[u].push_back({w,v});
out[u]++;
}
for (int i = 1; i <= n; i++) {
sort(adj[i].begin(), adj[i].end());
deg[out[i]].push_back(i);
}
init();
dfs(1, 0);
cout << ans << endl;
}