點分治小結
演算法介紹
點分治,顧名思義,是一種對點進行分治的資料結構。(樹上的點)
多用於在樹上進行有限制的路徑計數。
比如:求樹上長度小於$ k$ 的簡單路徑條數。\((n \leq 10000)\)
直接做肯定是補星的。所以就需要點分治這種東西了。
需要統計的路徑肯定有這麼兩種:
- 1.經過根節點$ root $的路徑
- 2.不經過根節點\(root\) 的路徑
顯然第二種路徑對於某個節點\(u\) ,屬於第一種路徑。所以分治解決即可。
我們來考慮第一種情況如何解決。
處理出一個數組\(d\) ,表示從當前根節點\(u\) ,到各個子節點的距離。
那麼我們要統計的顯然就是\(d[u]+d[v]\leq k\) 的路徑\((u,v)\) 的個數。
這個東西可以通過在dfs求這個陣列時順便把所有的\(d\) 值記錄下來,排序之後讓他們具有單調性。
然後雙指標掃一下就好(合法狀態就是\(d[l]+d[r]\leq k\) )那麼當指標在\(l\) 時,對答案的貢獻就是\(r-l\) (不能重複選自己,所以不+1)
然後現在考慮一種情況。當\(u,v\) 都在當前根節點的同一個子樹裡面。這樣子的話,路徑\((u,v)\) 如果經過根節點就不是一條簡單路徑了(重邊)。如何解決呢?
容斥的思想!
對於每個子樹,分別處理它其中的子節點的d值,給答案減去就行了!
程式碼大概就長這個樣子
void dfs(int u) { vis[u] = 1; ans += solve(u, 0); //所有情況 for(int i = head[u]; i; i = e[i].nxt) { if(vis[e[i].to]) continue; int v = e[i].to; ans -= solve(v, e[i].v); //減掉不合法情況 //下面是找重心的程式碼,後面會解釋為什麼要找重心 now_sz = inf, root = 0; sz = siz[v]; find_root(v, 0); dfs(root); } }
先不管為什麼要找重心。我們總結一下演算法流程:
- 1.找一個根節點root
- 2.對root計算出d陣列並計算答案
- 3.把root刪了,對root的各個子樹執行流程1,2
複雜度是多少呢?粗略估計一下是\(O(Tnlogn) \) ,\(T\) 是樹的層數。
顯然我們要讓這個樹優美一點,身材圓潤一點,不能瘦成一條鏈,不然複雜度就變成\(O(n^2logn) \) 了。
那這個根節點怎麼找呢?樹的重心 !
將重心當做根節點,可以保證樹是\(log\) 層的!
那麼複雜度就是$O(nlog^2n) $了!
還有就是關於點分治這裡的重心有兩種找法。一種就是上面那樣的,另外一種就是改了一句
sz = siz[v];
->sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
實際上第二種才是對的,因為v可能在上次處理siz陣列時是u的父親(這是一棵無根樹!)
但是複雜度並不會退化qwq,有神仙證明了。連結
例題:
POJ1741 tree
真正的模板題。就是我上面提到的那個問題。
直接點分一下就好了。每次將距離排序一下,然後雙指標掃一掃,每次合法答案就是r-l,容斥一下將不合法情況減去即可。注意找重心不要寫錯不然複雜度就炸了。
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; #define inf 0x3f3f3f3f #define ll long long #define N 100010 inline void in(int &x) { x = 0; int f = 1; char c = getchar(); while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); } while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); } x *= f; } int n, k, d[N], cnt, head[N], ans; int vis[N], siz[N]; struct edge { int to, nxt, v; }e[N<<1]; void ins(int u, int v, int w) { e[++cnt] = (edge) {v, head[u], w}; head[u] = cnt; } int now_sz = inf, root = 0, sz; void find_root(int u, int fa) { siz[u] = 1; int res = 0; for(int i = head[u]; i; i = e[i].nxt) { if(vis[e[i].to] || e[i].to == fa) continue; int v = e[i].to; find_root(v, u); siz[u] += siz[v]; res = max(res, siz[v]); } res = max(res, sz - siz[u]); if(res < now_sz) now_sz = res, root = u; } int a[N], tot; void get_dis(int u, int fa) { a[++tot] = d[u]; for(int i = head[u]; i; i = e[i].nxt) { if(vis[e[i].to] || e[i].to == fa) continue; int v = e[i].to; d[v] = d[u] + e[i].v; get_dis(v, u); } } int solve(int u, int dis) { d[u] = dis; tot = 0; get_dis(u, u); sort(a + 1, a + tot + 1); int l = 1, r = tot, res = 0; for(; l < r; ++l) { while(l < r && a[l] + a[r] > k) --r; if(l < r) res += r - l; } return res; } void dfs(int u) { vis[u] = 1; ans += solve(u, 0); for(int i = head[u]; i; i = e[i].nxt) { if(vis[e[i].to]) continue; int v = e[i].to; ans -= solve(v, e[i].v); now_sz = inf, root = 0; sz = siz[v]; find_root(v, 0); dfs(root); } } int main() { while(~scanf("%d%d", &n, &k) && n && k) { ans = 0; cnt = 0; memset(head, 0, sizeof(head)); memset(vis, 0, sizeof(vis)); for(int i = 1; i < n; ++i) { int u, v, w; in(u), in(v), in(w); ins(u, v, w), ins(v, u, w); } dfs(1); printf("%d\n", ans); } }
BZOJ2152: 聰聰可可
求倍數為3的路徑數。
考慮\(mod\ 3\) 意義下的路徑,為0顯然可以互相拼起來,貢獻是\(sum[0]^2\) 。1和2可以互相拼,而且起點終點互換,所以貢獻是\(sum[1]*sum[2]*2\) ,點分治計算這兩個即可。總方案數是\(n^2\) ,所以答案就是\(\frac{sum}{n^2}\)
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; #define inf 0x3f3f3f3f #define ll long long #define N 100010 inline void in(int &x) { x = 0; int f = 1; char c = getchar(); while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); } while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); } x *= f; } int n, k, d[N], cnt, head[N], ans; int vis[N], siz[N], sum[3]; struct edge { int to, nxt, v; }e[N<<1]; void ins(int u, int v, int w) { e[++cnt] = (edge) {v, head[u], w}; head[u] = cnt; } int now_siz, sz, root; void find_root(int u, int fa) { siz[u] = 1; int res = 0; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; find_root(v, u); siz[u] += siz[v]; res = max(res, siz[v]); } res = max(res, sz - siz[u]); if(res < now_siz) now_siz = res, root = u; } void get_dis(int u, int fa) { sum[d[u]%3]++; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(vis[v] || v == fa) continue; d[v] = d[u] + e[i].v; get_dis(v, u); } } int solve(int u, int dis) { d[u] = dis; sum[0] = sum[1] = sum[2] = 0; get_dis(u, u); return sum[0] * sum[0] + sum[1] * sum[2] * 2; } void dfs(int u) { ans += solve(u, 0); vis[u] = 1; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(vis[v]) continue; ans -= solve(v, e[i].v); now_siz = inf; sz = siz[v]; root = 0; find_root(v, u); dfs(root); } } int main() { in(n); for(int i = 1; i < n; ++i) { int u, v, w; in(u), in(v), in(w); ins(u, v, w), ins(v, u, w); } now_siz = inf; root = 0; sz = n; find_root(1, 1); dfs(root); int now = n * n, g = __gcd(now, ans); printf("%d/%d\n", ans / g, now / g); }
LuoguP3806 【模板】點分治1
注意這題資料很水...
求長度為k的路徑是否存在。多次詢問(詢問數\(\leq 100\) )
這題效率有點奇怪...
自己估算了一下是\(O(mnlog^2n)\) 。
對長度正好k的話,其實用個桶標記就好了,實際上和小於k沒多大區別的。
考慮先將詢問離線,然後在點分治過程中對所有答案進行判定。處理出d[]表示到節點i到當前根的距離。那麼照例是拼路徑,但是現在不是求方案總數而是求有沒有這個方案,看起來不能容斥了。但是實際上可以的:考慮先對根u solve一遍,給所有詢問加上這次的結果,然後對每個子節點計算一遍,給所有詢問減掉這次的結果就好了。
具體的話看看程式碼吧
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; #define inf 0x3f3f3f3f #define ll long long #define N 100010 #define lim 10000000 inline void in(int &x) { x = 0; int f = 1; char c = getchar(); while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); } while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); } x *= f; } int top, n, m, d[N], cnt, head[N], ans[110]; int vis[N], siz[N], q[110], st[N], s[10000010]; struct edge { int to, nxt, v; }e[N<<1]; void ins(int u, int v, int w) { e[++cnt] = (edge) {v, head[u], w}; head[u] = cnt; } int now_sz = inf, root, sz; void find_root(int u, int fa) { siz[u] = 1; int res = 0; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; find_root(v, u); res = max(res, siz[v]); siz[u] += siz[v]; } res = max(res, sz - siz[u]); if(res < now_sz) now_sz = res, root = u; } void get_dis(int u, int fa) { st[++top] = d[u]; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; d[v] = d[u] + e[i].v; get_dis(v, u); } } void solve(int u, int dis, int op) { top = 0; d[u] = dis; get_dis(u, 0); for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]++; for(int i = 1; i <= m; ++i) { for(int j = 1; j <= top; ++j) if(q[i] >= st[j]) ans[i] += s[q[i] - st[j]] * op; } for(int i = 1; i <= top; ++i) if(st[i] <= lim) s[st[i]]--; } void dfs(int u) { vis[u] = 1; solve(u, 0, 1); for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(vis[v]) continue; top = 0; d[v] = e[i].v; solve(v, e[i].v, -1); now_sz = inf, root = 0, sz = siz[v]; find_root(v, u); dfs(root); } } int main() { in(n), in(m); for(int i = 1; i < n; ++i) { int u, v, w; in(u), in(v), in(w); ins(u, v, w), ins(v, u, w); } for(int i = 1; i <= m; ++i) in(q[i]); sz = n; now_sz = inf; root = 0; find_root(1, 1); dfs(root); for(int i = 1; i <= m; ++i) puts(ans[i] ? "AYE" : "NAY"); }
CF161D Distance in Tree
求長度等於k的路徑數...就很煩....這種一般都要分類討論
需要分類討論一下,同樣是套路點分然後開個桶,然後分\(k-v[i]=v[i]\) 和不等兩種情況,顯然相等的話答案就是\(cnt[v[i]]*(cnt[v[i]]-1)/2\) .不相等的話用乘法原理考慮一下,\(cnt[v[i]]*cnt[k-v[i]]\) ,注意每次統計完之後就要把cnt清空。
#include <bits/stdc++.h> #define ll long long #define inf 0x3f3f3f3f #define il inline namespace io { #define in(a) a = read() #define out(a) write(a) #define outn(a) out(a), putchar('\n') #define I_int ll inline I_int read() { I_int x = 0, f = 1; char c = getchar(); while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); } while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); } return x * f; } char F[200]; inline void write(I_int x) { if (x == 0) return (void) (putchar('0')); I_int tmp = x > 0 ? x : -x; if (x < 0) putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0) putchar(F[--cnt]); } #undef I_int } using namespace io; using namespace std; #define N 100010 int n, k; int cnt, head[N], vis[N], d[N]; struct edge { int to, nxt; }e[N<<1]; void ins(int u, int v) { e[++cnt] = (edge) {v, head[u]}; head[u] = cnt; } int siz[N], now_sz = inf, root, sz; void find_root(int u, int fa) { siz[u] = 1; int res = 0; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; find_root(v, u); siz[u] += siz[v]; res = max(res, siz[v]); } res = max(res, sz - siz[u]); if(res < now_sz) now_sz = res, root = u; } int top, st[N], s[N]; void get_dis(int u, int fa) { st[++top] = d[u]; if(d[u] <= k) ++s[d[u]]; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; d[v] = d[u] + 1; get_dis(v, u); } } ll solve(int u, int dis) { d[u] = dis; top = 0; get_dis(u, 0); ll ans = 0; for(int i = 1; i <= top; ++i) if(st[i] <= k) { if(st[i] * 2 == k) ans += 1ll * s[st[i]] * (s[st[i]] - 1) / 2ll; else ans += 1ll * s[k - st[i]] * s[st[i]]; s[st[i]] = s[k - st[i]] = 0; } return ans; } ll ans = 0; void dfs(int u) { vis[u] = 1; ans += solve(u, 0); int totsiz = sz; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if(vis[v]) continue; ans -= solve(v, 1); sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v]; now_sz = inf; root = 0; find_root(v, 0); dfs(root); } } int main() { in(n), in(k); for(int i = 1; i < n; ++i) { int u = read(), v = read(); ins(u, v), ins(v, u); } now_sz = inf; sz = n; root = inf; find_root(1, 0); dfs(root); outn(ans); }