题意:
给出一颗树,每个结点有取值范围\([1,D]\)。
现在有限制条件:对于一个子树,根节点的取值要大于等于子数内各结点的取值。
问有多少种取值方案。
思路:
- 手画一下发现,对于一颗大小为\(sz\)的数,最终的答案为一个\(sz+1\)次为最高次幂的多项式。
- 因为节点数\(n\leq 3000\),所以暴力求出后插值即可。
简略证明:对于一个链,显然,一个长度为\(x\)的链,最终的结果为\(x+1\)次的多项式;考虑两条链的合并:长度为\(x\)的链和长度为\(y\)的链,显然两者相乘最终为\(x+y+2\)次的多项式,因为合并过后会多一个父节点,那么就是有\(x+y+1\)个点。
归纳一下就有上面说的结论了。
代码如下:
/* * Author: heyuhhh * Created Time: 2019/11/18 20:20:04 */ #include <bits/stdc++.h> #define MP make_pair #define fi first #define se second #define sz(x) (int)(x).size() #define all(x) (x).begin(), (x).end() #define INF 0x3f3f3f3f #define Local #ifdef Local #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0) void err() { std::cout << '\n'; } template<typename T, typename...Args> void err(T a, Args...args) { std::cout << a << ' '; err(args...); } #else #define dbg(...) #endif void pt() {std::cout << '\n'; } template<typename T, typename...Args> void pt(T a, Args...args) {std::cout << a << ' '; pt(args...); } using namespace std; typedef long long ll; typedef pair<int, int> pii; //head const int N = 3005, MOD = 1e9 + 7; ll qpow(ll a, ll b) { ll ans = 1; while(b) { if(b & 1) ans = ans * a % MOD; a = a * a % MOD; b >>= 1; } return ans; } int n, D; vector <int> g[N]; int res[N]; int pre[N][N]; void dfs(int u, int fa) { int son = 0; for(auto v : g[u]) if(v != fa) { dfs(v, u); ++son; } if(!son) { for(int i = 1; i <= n; i++) pre[u][i] = i; } else { for(int i = 1; i <= n; i++) res[i] = 1; for(auto v : g[u]) if(v != fa) { for(int i = 1; i <= n; i++) res[i] = 1ll * res[i] * pre[v][i] % MOD; } for(int i = 1; i <= n; i++) pre[u][i] = (pre[u][i - 1] + res[i]) % MOD; } } struct Lagrange { static const int SIZE = 3005; ll f[SIZE], fac[SIZE], inv[SIZE], pre[SIZE], suf[SIZE]; int n; inline void add(ll &x, int y) { x += y; if(x >= MOD) x -= MOD; } void init(int _n) { n = _n; fac[0] = 1; for (int i = 1; i < SIZE; ++i) fac[i] = fac[i - 1] * i % MOD; inv[SIZE - 1] = qpow(fac[SIZE - 1], MOD - 2); for (int i = SIZE - 1; i >= 1; --i) inv[i - 1] = inv[i] * i % MOD; f[0] = 0; } ll calc(ll x) { if (x <= n) return f[x]; pre[0] = x % MOD; for (int i = 1; i <= n; ++i) pre[i] = pre[i - 1] * ((x - i) % MOD) % MOD; suf[n] = (x - n) % MOD; for (int i = n - 1; i >= 0; --i) suf[i] = suf[i + 1] * ((x - i) % MOD) % MOD; ll res = 0; for (int i = 0; i <= n; ++i) { ll tmp = f[i] * inv[n - i] % MOD * inv[i] % MOD; if (i) tmp = tmp * pre[i - 1] % MOD; if (i < n) tmp = tmp * suf[i + 1] % MOD; if ((n - i) & 1) tmp = MOD - tmp; add(res, tmp); } return res; } }lagrange; void run(){ for(int i = 2; i <= n; i++) { int x; cin >> x; g[i].push_back(x); g[x].push_back(i); } lagrange.init(n); dfs(1, 0); for(int i = 1; i <= n; i++) lagrange.f[i] = pre[1][i]; int ans = lagrange.calc(D); cout << ans; } int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cout << fixed << setprecision(20); while(cin >> n >> D) run(); return 0; }