快速求區間和的有趣演算法——樹狀陣列
好久沒寫東西,感覺有寫些什麼的必要了。
樹狀陣列雖然聽起來名字高大上,但是不是很難(字首和是名字高大上,卻水得像海洋)
樹狀陣列在單純的查詢一個區間的和和修改某一個數的效率要超過線段樹哦!樹狀陣列最差時間複雜度為O(logn),而線段樹的時間複雜度一直保持O(logn),且線段樹的空間複雜度是樹狀陣列的4倍。
But:樹狀陣列只是線段樹的一個辣雞版本(雖然在某些方面比線段數快一點點),使用樹狀陣列很大的一個原因是樹狀陣列十分好寫,且非常好維護。但是它只能處理可以用字首和或差分來解決的題目,像是求(l,r)之間的最大值,樹狀陣列就會Game Over。
為什麼叫樹狀陣列呢?因為它長得像樹一樣(廢話),就像這個樣子:
表示我的畫圖技術和畫圖軟體都爛炸了
現在假如有n個數,存在A數組裡,用C陣列當樹狀陣列,從A[1]開始存入,一直存到A[n],然後順便把C陣列初始化。(一會兒解釋為什麼不從A[0]開始存)
通過看圖,可以得到這麼一個結論:
C1 = A1 C2 = A1+A2 C3 = A3 C4 = A1+A2+A3+A4 C5 = A5 C6 = A5+A6 C7 = A7 C8 = A1+A2+A3+A4+A5+A6+A7+A8
現在找找規律!
好吧,是不是有感覺了但是表達不出來?
再處理一下,把C陣列的下標用二進位制表示出來
C(1) = A1 C(10) = A1+A2 C(11) = A3 C(100) = A1+A2+A3+A4 C(101) = A5 C(110) = A5+A6 C(111) = A7 C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
我們把這些下標的二進位制從後面往前面看,看到出現一個1為止:
C(1) = A1 C(10) = A1+A2 C(1) = A3 C(100) = A1+A2+A3+A4 C(1) = A5 C(10) = A5+A6 C(1) = A7 C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
最後一個處理,把讀後的二進位制下標再轉換成十進位制:
C(1) = A1 C(2) = A1+A2 C(1) = A3 C(4) = A1+A2+A3+A4 C(1) = A5 C(2) = A5+A6 C(1) = A7 C(8) = A1+A2+A3+A4+A5+A6+A7+A8
現在絕對能看懂了(自信滿滿)
好了,“顯然”可以看出,假如原來C陣列的下標為a,現在的下標為b,那麼這個C[a]就對應著從A[b]和它前面總共b個數的和,或者可以說,對應著從A[a-b+1]到A[b]的數的和。(原本很簡單的東西為啥講出來這麼麻煩?)
有人要說了:講這麼半天,你也沒有告訴我怎麼初始化(玩)樹狀陣列。
好吧好吧,現在開始講(che)解(dan)。
還是從一道模板題開始最好了QvQ!(然後從一道模板題結束......)
題目描述
如題,已知一個數列,你需要進行下面兩種操作:
1.將某一個數加上x
2.求出某區間每一個數的和
輸入輸出格式
輸入格式:
第一行包含兩個整數N、M,分別表示該數列數字的個數和操作的總個數。
第二行包含N個用空格分隔的整數,其中第i個數字表示數列第i項的初始值。
接下來M行每行包含3個整數,表示一個操作,具體如下:
操作1: 格式:1 x k 含義:將第x個數加上k
操作2: 格式:2 x y 含義:輸出區間[x,y]內每個數的和
輸出格式:
輸出包含若干行整數,即為所有操作2的結果。
輸入輸出樣例
輸入樣例:
5 5
1 5 4 2 3
1 1 3
2 2 5
1 3 -1
1 4 2
2 1 4
輸出樣例:
14
16
說明
時空限制:1000ms,128M
資料規模:
對於30%的資料:N<=8,M<=10
對於70%的資料:N<=10000,M<=10000
對於100%的資料:N<=500000,M<=500000
這道題明顯是要用樹狀陣列做嘛!(笑)
題目的意思非常好懂,就是n個數,m個操作。操作分兩種,一種是查詢區間和,一種是修改(增加)第幾個數的值。
那麼開始碼程式碼吧,先從主函式(main)開始:
因為A數組裡的所有值在C數組裡都能查詢到,所以並不需要建一個A陣列,只需要讀一個數,然後把C陣列更新一下便好了。
程式碼如下:
cin>>n>>m;//分別是n個數m個操作 for(int i=1;i<=n;i++){ int v; cin>>v; update(i,v);//這個函式是在序列(假想的A陣列)第i個位置加上v,因為初始都是零,所以相當於初始化,這個函式的實現後面講 }
讀入之後,就是m個操作了:
for(int i=1;i<=m;i++){ int k,a,b; cin>>k>>a>>b;//k是模式(題中有),a和b下面要用到 if(k==1) update(a,b);//在序列的第a個位置加上b else//如果不是在某個位置上加一個數,就是求區間和啦 cout<<sum(b)-sum(a-1)<<endl;//輸出區間和,這個看不懂不要緊,後面講(怎麼什麼都後面講poq) }
好吧,現在主函式除了return 0;以外都寫完了,現在就到了講(che)最難的update和sum函數了(其實還有一個lowbit函式)
先回到那個
圖:
唉,醜陋不堪!
再把最開始得到的結論搬出來:
C1 = A1 C2 = A1+A2 C3 = A3 C4 = A1+A2+A3+A4 C5 = A5 C6 = A5+A6 C7 = A7 C8 = A1+A2+A3+A4+A5+A6+A7+A8
還有這個:
C(1) = A1 C(10) = A1+A2 C(11) = A3 C(100) = A1+A2+A3+A4 C(101) = A5 C(110) = A5+A6 C(111) = A7 C(1000) = A1+A2+A3+A4+A5+A6+A7+A8
舉個栗子,假如我們修改了A[3]的值,C陣列中的哪些元素需要修改呢?
通過看圖和看圖後得到的結論,“顯然“就是包含A[3]的C陣列的元素,或者說是C[3]和它的”祖先”(反正人們說樹都喜歡用祖先這個詞),也就是C[3],C[4]和C[8]。
因為你不能給計算機一個xxx.jpg然後讓它自動修改需要修改的C陣列的元素是不是?所以現在,得到一個遞推式來自動處理顯得很有必要了。
現在你不用自己找了,因為已經有人幫你找好了。
我們設x下標的二進位制從後面往前面看,看到出現一個1時,我們看過的二進位制為lowbit(x),如,3的二進位制是11,那麼lowbit(3)便是1了,又如4的二進位制是100,那麼lowbit(4)就是100了。
如果我們把3加上lowbit(3),得到4,再把4加上lowbit(4),就得到我們要的8了,這樣,就愉快地把要修改的C陣列的元素全部找到了。
先把lowbit函式給寫了吧:
int lowbit(int x){ return x& (-x); }
至於這個lowbit裡面是怎麼回事,因為涉及到補碼什麼的,就不講了,反正也很好記๑乛◡乛๑
不過不能一直加下去吧?邊界條件很好找,就是x不會超過n(顯然易見的)。
現在把update函式也放出來:
void update(int x,int v){ while(x<=n){//邊界條件 c[x]+=v;//將要更新的C陣列的元素加上v x+=lowbit(x);//下一個元素 } }
之前有一個問題,就是為什麼A陣列不從0開始,因為lowbit(0)等於0,那麼就會永遠達不到邊界條件,也就是x永遠也不會達到n,總之會無限迴圈下去,就炸了,炸了!
好啦,現在唯一沒有講(che)的是主函式中的這句話了:
cout<<sum(b)-sum(a-1)<<endl;
很簡單,sum(x)函式是計算序列中第1個數到第x個數的和的函式(繞暈),和字首和的思想相同,若想求第a個數到第b個數的和,只需要求第1個數到第b個數的和減第1個數到第a-1個數的和即可
那麼是時候講(che)sum函式的構造了!
假如要求序列中第1個數到第7個數的和該怎麼弄?看看錶就明白了——>C[7]+C[6]+C[4],再拆成二進位制C(111)+C(110)+C(100)。那麼假如要求序列中第1個數到第6個數的和呢?再看一下表C[6]+C[4],再拆成二進位制C(110)+C(100)。
可以看出來,要求第1個數到第x個數的和,只需要從x開始向下遞推,然後用一個變數將一堆C[x]加起來,就可以得到第1個數到第x個數的和了,邊界條件也是“顯然易見”的,那就是x>0或x>=1。
話不多說,上程式碼:
int sum(int x){ int res=0;//儲存一堆C[x]的和的變數 while(x>0){//邊界條件 res+=c[x];//加上...... x-=lowbit(x);//下一個 } return res; }
這樣,這道題就可以AC了!
附上完整程式碼:
#pragma GCC optimize(3) #include<bits/stdc++.h> using namespace std; static int n,m; static int c[500005]; inline int lowbit(int x){ return x& (-x); } void update(int x,int v){ while(x<=n){ c[x]+=v; x+=lowbit(x); } } int sum(int x){ int res=0; while(x>0){ res+=c[x]; x-=lowbit(x); } return res; } int main(){ cin>>n>>m; for(int i=1;i<=n;i++){ int v; cin>>v; update(i,v); } for(int i=1;i<=m;i++){ int k,a,b; cin>>k>>a>>b; if(k==1) update(a,b); else cout<<sum(b)-sum(a-1)<<endl; } return 0; }
請無視我手動開的O3和C++17中全域性變數必須加的static......