讓智慧體主動互動,DeepMind提出用元強化學習實現因果推理
論文:https://arxiv.org/pdf/1901.08162.pdf
發現和利用環境中的因果結構是智慧體面臨的一大關鍵挑戰。這裡我們探索了是否可通過元強化學習來實現因果推理(cause reasoning)。我們使用無模型強化學習訓練了一個迴圈網路來求解一系列包含因果結構的問題。我們發現,訓練後的智慧體能夠在全新的場景中執行因果推理,從而獲得獎勵。智慧體可以選擇資訊干預、根據觀察資料得出因果推論以及做出反事實的預測。儘管也存在已有的形式因果推理演算法,但我們在這篇論文中表明這樣的推理可以由無模型 強化學習 產生,並提出這裡給出的更多端到端的基於學習的方法也許有助於在複雜環境中的因果推理。通過讓智慧體具備執行——以及解釋——實驗的能力,本研究也能為強化學習中的結構化探索提供新的策略。
1 引言
很多機器學習演算法的根基都是發現數據中的相關模式。儘管這種方法足以應對許多領域(Krizhevsky et al., 2012; Cho et al., 2014),但有時候我們感興趣的問題具有固有的因果性質。在回答「吸菸是否導致癌症?」或「這個人工作被拒的原因是種族歧視嗎?」或「是這個營銷活動導致了銷量上漲嗎?」這些問題時,需要有推理因果的能力。因果推理可能是自然智慧的一大關鍵元件,在人類嬰兒、大鼠甚至鳥類身上都有體現(Leslie, 1982; Gopnik et al., 2001; 2004; Blaisdell et al., 2006; Lagnado et al., 2013)。
有關定義和執行因果推理的形式方法的文獻很豐富(Pearl, 2000; Spirtes et al., 2000; Dawid, 2007; Pearl et al., 2016)。我們研究了能否通過 元學習 實現這樣的推理。元學習方法是指直接從資料中學習「學習(或推斷/估計)過程」自身。人類智慧也與類比模型(Grant et al., 2018)有密切聯絡(Goodman et al., 2011),這種模型是直接從環境中學習因果結構,而沒有一個預先設計的形式理論。
我們特別採用了之前的研究(Duan et al., 2016; Wang et al., 2016)引入的「元強化學習」,其中使用無模型強化學習(RL)方法訓練了一個基於迴圈神經網路(RNN)的智慧體。通過在多種類別的結構化任務上進行訓練,這個 RNN 變成了一個能泛化到取自類似分佈的新任務上的學習演算法。在我們的案例中,我們在一個任務分佈上進行了訓練,其中每一個任務都有一個不同的因果結構作為支撐。我們關注的是能最好地隔離相關問題的抽象任務:當不向智慧體明確提供因果概念時,元學習可否產生能執行因果推理的智慧體。
元學習能端到端地學習,通過分攤計算而提供可擴充套件性的優勢,該演算法有望找到最適用於所需因果推理型別的因果結構的內部表徵(Andrychowicz et al., 2016; Wang et al., 2016; Finn et al., 2017)。我們重點關注強化學習的原因是我們感興趣的不僅是讓智慧體根據被動觀察學習因果,而且也能通過與環境的主動互動來學習(Hyttinen et al., 2013; Shanmugam et al., 2015)。
2 問題說明與方法
我們研究了三種明顯不同的資料設定——觀察的、有干預的和反事實的。這些不同設定測試的是不同型別的推理。
-
在觀察式設定中(實驗 1),智慧體僅能從環境中獲取被動的觀察資料。這種型別的資料可讓智慧體推斷相關性(關聯性推理/associative reasoning),並且還能根據環境的結構推斷因果關係(因果性推理/cause-effect reasoning)。
-
在有干預的設定中(實驗 2),智慧體可通過設定某些變數的值以及觀察對其它變數的影響而在環境中採取行動。這種型別的資料有助於對因果關係的估計。
-
在反事實的設定中(實驗 3),智慧體首先有機會通過互動來了解環境的因果結構。在 episode 的最後一步,它必須回答一個反事實的問題,該問題的形式為「如果在之前的時間步驟進行不同的干預會怎樣?」
接下來我們將使用圖模型框架(Pearl, 2000; Spirtes et al., 2000; Dawid, 2007)對這些設定以及每種設定中可能的推理模式進行形式化。隨機變數用大寫字母標註,它們的值用小寫字母標註。
2.1因果推理
隨機變數之間的因果關係可以使用因果貝葉斯網路(CBN,詳見補充材料)表示。CBN 是一種有向無環圖模型,既能表示獨立關係,也能表示因果關係。每個節點 X_i 對應於一個隨機變數,並且聯合分佈 p(X_1,...,X_N) 是根據每個節點 X_i 的父節點 pa(X_i) 通過求每個節點 X_i 的條件分佈的積而得到的,即:
邊帶有因果語義資訊:如果存在一條從 X_i 指向 X_j 的路徑,則 X_i 就是 X_j 的一個潛在原因。有向路徑也被稱為因果路徑。X_i 對 X_j 的因果影響是給定限定在僅有因果路徑的 X_i 時 X_j 的條件分佈。
圖 1a 給出了一個 CBN 示例,其中 E 表示一週的鍛鍊小時數,H 表示心臟健康情況,A 表示年齡。E 對 H 的因果影響是限定在路徑 E→H 的條件分佈,即不包括路徑 E←A→H。變數 A 被稱為混雜變數(confounder),因為它將因果影響與非因果的統計影響混雜到了一起。只是通過 p(H|E) 基於鍛鍊水平觀察心臟健康狀況(關聯性推理)不能解答鍛鍊水平改變是否會造成心臟健康變化的問題(因果性推理),因為總是存在這樣的可能性:這兩者之間的關聯源自共有的年齡混雜變數。
圖 1:(a)有一個混雜變數的因果貝葉斯網路(CBN):年齡(A)和鍛鍊身體(E)對健康(H)的影響。(b)受干預 CBN,通過將 p(E|A) 替換成一個 δ 分佈 δ(E−e) 而對前面的 CBN 進行了修改,條件分佈 p(H|E,A) 和 p(A) 保持不變。
2.2元學習
元學習是一類範圍廣泛的方法,其從資料中學習的是學習演算法本身的各個方面。深度學習演算法的很多單個元件都可通過元學習成功得到,包括優化器(Andrychowicz et al., 2016)、初始引數設定(Finn et al., 2017)、度量空間(Vinyals et al., 2016)、外部記憶的使用(Santoro et al., 2016)。
按照(Duan et al., 2016; Wang et al., 2016)的方法,我們將整個學習演算法引數化為了一個迴圈神經網路(RNN),然後我們使用無模型強化學習來訓練這個 RNN 的權重。這個 RNN 是在一個寬廣的問題分佈上訓練的,其中每個問題都需要學習。當以這種方式訓練時,RNN 可以實現能有效求解訓練分佈的同分布或相近分佈中全新的學習問題(更多細節請參閱補充材料)。
通過無模型強化學習學習 RNN 的權重可被視為學習的「外環(outer loop)」。外環將 RNN 的權重整合進一個「內環」學習演算法中。這個內環演算法會在 RNN 的啟用動態中一直執行,即使當該網路的權重被凍結時也能繼續學習。這個內環演算法也可以與用於訓練它的外環演算法有非常不同的性質。比如,在之前的工作中,這種方法曾被用於協調多臂賭博機問題中的探索-利用權衡(Duan et al., 2016),也曾被用於學習能動態調整自身學習率的演算法(Wang et al., 2016; 2018)。我們在本研究中探索了獲取可感知因果的內環學習演算法的可能性。
3 任務設定和智慧體架構
在我們的實驗中,智慧體在每個 episode 中都會和一個不同的 CBN 互動,這些 CBN 由 N 個變數定義。CBN 的結構是從可能的無環圖空間中隨機取出的,其取出方式的限定條件將在後續小節說明。
每個 episode 包含 T 個步驟,可分為兩個階段:資訊階段和測驗階段。資訊階段對應於前 T-1 個步驟,讓智慧體可通過與 CBN 互動或被動觀察 CBN 的樣本來收集資訊。智慧體有望使用這些資訊來推斷 CBN 的連線方式和權重。測驗階段對應於最後一個步驟,要求智慧體利用其在資訊階段收集到的因果資訊,從而在存在隨機外部干預時選擇出值最高的節點。
智慧體架構和訓練
我們使用了一個長短期記憶(LSTM)網路(Hochreiter and Schmidhuber, 1997)(有 192 個隱藏單元)。在每個時間步驟 t,該網路都接收一個包含 的連線向量作為輸入,其中,o_t 是觀察,a_(t-1) 是前一個動作(是一個 one-hot 向量),r_(t-1) 是獎勵(是單個實數值)。
其輸出是作為 LSTM 的隱藏狀態的線性投射而計算的,是一組策略 logits(其維度等於可用動作的數量),加上一個標量基線。這個策略 logits 會由一個 softmax 函式變換,然後再被取樣以給出一個所選的動作。
學習使用了非同步優勢 actor-critic(Mnih et al., 2016),其損失函式包含三項——策略梯度、基線成本和一個熵成本。基線成本由相對於策略梯度成本 0.05 進行加權。熵成本的權重是在訓練過程中從 0.25 到 0 退火式衰減。優化由 RMSProp 完成,其 ε=10^-5,動量=0.9,衰減率=0.95。學習率從 9×10^−6 到 0 退火式衰減,折扣因子為 0.93。除非另有說明,訓練完成要執行 1×10^7 步,使用了批大小為 1024 的分批式環境。
對於所有實驗,在訓練完成之後,都在一個留存測試集上對智慧體進行測試,學習率設為零。
4 實驗
圖 2:實驗 1。智慧體根據觀察資料執行因果性推理。a)實驗中測試的智慧體得到的平均獎勵。b)在有外部干預的節點上根據至少存在或不存在一個父節點(分別表示為 Parent 和 Orphan)而劃分的表現。c)一個測試 CBN 的測驗階段。綠色和紅色邊分別表示 +1 和 -1 的權重。黑色表示被幹預的節點,綠色和紅色節點分別表示該節點的值為正和負,白色表示為零。藍色圓圈表示該智慧體的選擇。
圖 3:實驗 2。智慧體根據干預資料執行因果性推理。a)實驗中測試的智慧體得到的平均獎勵。b)在有外部干預的節點上根據存在或不存在未被觀察的混雜變數(分別表示為 Conf. 和 Unconf.)而劃分的表現。c)一個測試 CBN 的測驗階段。
圖 6:實驗 3。智慧體執行反事實推理。a)實驗中測試的智慧體得到的平均獎勵。b)根據測驗階段中最大節點值是退化的(Deg.)或明顯不同的(Dist.)而劃分的表現。c)一個測試 CBN 的測驗階段。