Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

收藏待读

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

改進 GAN 除了使用更複雜的網絡結構和損失函數外,還有其他簡單易行的方法嗎?Uber 的這篇文章或許可以給你答案,將 GAN 與貝葉斯方法相結合,在已經訓練好的 GAN 上增加後處理步驟即可。本文對 Uber 的這篇最新工作進行了簡要介紹,如果對內容感興趣還可以點擊文末的原文鏈接閱讀論文,同時文末還提供了該方法的開源代碼,你可以輕鬆用它來提升自己的 GAN 模型。

更多乾貨內容請關注微信公眾號「AI 前線」(ID:ai-front)

生成對抗網絡(GAN)不僅在 真實感圖像生成圖像恢復 方面取得了令人驚嘆的效果,並且由 GAN 生成的一幅藝術作品也售出了 40 萬美元的價格。

在 Uber,GAN 有大量具有潛力的應用,包括增強機器學習模型與對抗性攻擊的對抗能力,學習交通模擬器,乘車請求或隨時間變化的需求模式,以及 為 Uber Eats 生成個性化的訂單建議

GAN 由兩個互相對抗的部分組成 ,一部分是生成器,一部分是判別器。生成器學習真實數據的分佈,判別器負責需要學習如何區別真實樣本和生成樣本(即假樣本)。大多數研究都致力於改進 GAN 的結構和訓練過程來提高其性能,例如使用 更大的網絡結構 或使用不同的損失函數。

NeurIPS2018 的 貝葉斯深度學習研討會 上,Uber 的一篇論文中提供了一種新的思路:調整判別器用於在完成訓練後從生成器中選擇更好的樣本。該工作提供了一種互補的抽樣方法,Google 和 U.C. Berkeley 在判別器舍選抽樣( Discriminator Rejection Sampling ,DRS)的研究與此方法也具有相同的思路。

Uber 這篇工作以及 DRS 方法的核心思想可歸納為,如何使用已經訓練好的判別器的信息來從生成器中選擇樣本,以保證這些被選擇的樣本儘可能符合真實數據的分佈。通常,在訓練完成後判別器就沒有什麼用了,因為在訓練過程中會將判別器學到的知識編碼到生成器中。然而,生成器往往是不完美的,判別器同時也會含有一些有用的信息,所以上述使用判別器信息來提升已經訓練好的 GAN 的方法是值得一試的。Uber 的研究團隊使用了 Metropolis-Hastings 算法對分佈進行抽樣,並將採用這種方法得到的模型稱為 Metropolis-Hastings GAN,即 MH-GAN。

GAN 重抽樣

GAN 的訓練過程通常被理解為兩種條件之間的博弈,生成器需要儘可能讓判別器產生誤判的概率最大化,而判別器則需要儘可能的對真 1z 實數據和生成數據進行良好的區分。圖 1 展示了這個過程,生成器使得函數值向極小值方向移動(橙色線條),而判別器則向極大值方向移動(紫色線條)。訓練結束後,向生成器輸入不同的隨機噪聲可以得到很方便得到生成樣本。如果可以訓練一個完美的生成器,那麼生成器最終的概率密度函數 pG 應與真實數據的概率密度函數相同。然而,許多現有的 GAN 無法很好地收斂到真實數據的分佈 ,因此從這種不完美的生成器中抽樣會產生看起來不像原始訓練數據的樣本。

這種 pG 的不完美讓我們想到另一種分佈情況:判別器對生成器隱含的概率密度。這種分佈被稱為 pD ,並且它往往都很接近真實的數據分佈 pG 。這是因為訓練判別器是一種比訓練生成器更簡單的任務,因此判別器很有可能包含可以用於校正生成器的信息。如果我們有一個完美的判別器 D 和一個不完美的生成器 G,使用 pD 而不是 pG 作為生成的概率密度函數等價於使用一個新的生成器 G』,並且這個 G』是可以完美地模擬真實數據分佈的,如圖一所示:

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 1:等高線圖展示了 GAN 訓練中的對抗過程,聯合函數的值在極小化和極大化之間交替進行。橙色線條表示生成器 G 的優化過程,紫色線條表示判別器 D 的優化。假設 GAN 的訓練過程結束於圖中(D,G)這一點,此時的 G 未處於最優點,但對於這個 G 來說 D 是最優的。此時,通過從 pD 的分佈中抽樣,可以得到一個能夠完美對數據分佈建模的新的生成器 G’。

即使 pD 的分佈可能與數據更匹配,但若想利用其得到樣本數據並不像直接使用生成器那樣直接。幸運的是,我們可以使用抽樣算法從分佈中產生樣本,一種是舍選抽樣法(Rejection Sampling,也被稱為 Acceptance-Rejection Sampling),一種是馬爾科夫鏈蒙特卡洛法(Markov Chain Monte Carlo,MCMC)。這兩種方法都可以作為一種後處理方法來提高生成器的輸出;之前的判別器舍選抽樣法( Discrimitor Rejection Sampling ,DRS)借鑒了舍選抽樣法的思路,而 MH-GAN 則採用了 Metropolis-Hastings MCMC 方法。

舍選抽樣

很多實際問題中,真實分佈 p(x) 是很難直接抽樣的的,因此,我們需要求助其他的手段來抽樣。既然 p(x) 太複雜在程序中沒法直接抽樣,那麼我們可以設定一個程序可抽樣的分佈 q(x) 比如高斯分佈,然後按照一定的方法拒絕某些樣本,達到接近 p(x) 分佈的目的,其中 q(x) 叫做候選分佈(Proposal Distribution)。

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 2: 舍選抽樣

具體操作如下,設定一個方便抽樣的函數 q(x),以及一個常量 k,使得 p(x) 總在 kq(x) 的下方。(參考上圖)

  • x 軸方向:從 q(x) 分佈抽樣得到 a。

  • y 軸方向:從均勻分佈(0, kq(a)) 中抽樣得到 u。

  • 如果剛好落到灰色區域即 u > p(a),則拒絕,否則接受這次抽樣。

重複以上過程便可得到 p(x) 的近似分佈。該方法兩大挑戰分別是:

  1. k 的值通常是人為經驗設置的,無法確定一個準確的值。若 k 值設置的過大可能導致拒絕率很高,增加無用計算;若 k 值過小則有可能找不到正確的 p(x) 分佈。

  2. 合適的 q(x) 分佈通常很難找到。

在 GAN 中,pD 即為目標分佈對應上述 p(x),pG 為現有的分佈對應上述 q(x)。所以在 GAN 中使用該方法的難點主要來源於 k 值的確定,或因 k 值太小而無法正確抽樣,或因 k 值過大而在高維空間中產生大量的計算。為了解決樣本浪費問題,DRS 啟發式地增加了一個γ調整判別器分數,使得判別器 D 即使是完美的情況下,從分佈中產生的樣本仍能夠與真實樣本存在差異。

更好的途徑:Metropolis-Hastings

Uber 的這篇工作使用了 Metropolis-Hastings(MH)方法,這是馬爾科夫鏈蒙特卡洛法一類方法中的一種。這一類方法被最初是作為舍選抽樣法在高維空間中的代替而發明的,它們通過從候選分佈中多點抽樣得到一個儘可能複雜的概率分佈,然後再對這個概率分佈進行抽樣。MH 包含兩步,第一步是從候選分佈中(例如,生成器)選擇 K 個樣本,然後從 K 中依次選擇一個樣本,決定是接受當前樣本還是根據接受規則保留先前選擇的樣本,如圖 3 所示:

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 3:MH 在馬爾科夫鏈中選擇 K 個樣本,然後根據接受規則對每個樣本作出選擇。這個馬爾科夫鏈最終會輸出最終接受的樣本。對於 MH-GAN 而言,K 個樣本由 G 生成,馬爾科夫鏈的輸出由改進後的 MH-GAN’的 G’產生

MH-GAN 最大的特點是接受概率可以僅由概率密度比值 pD/pG 計算得到,而 GAN’的判別器的輸出恰巧可以計算這個比值!假設 xk 為初始樣本,新的樣本 x’可以通過與當前樣本 xk 的概率 d 計算而被接受。

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

其中,D 是判別器分數,由以下公式得到

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

K 是一個超參數,對其調整可以在速度和置信度之間做出權衡。對於一個完美的判別器 K 趨近於無窮,即 D 的分佈完美的接近了真實數據分佈。

MH-GAN 更多細節

1. 獨立抽樣

噪聲樣本被獨立地輸入生成器,經過 K 次生成得到可以符合 MH 選擇器條件的狀態鏈。獨立的鏈被用於從 MH-GAN 的生成器 G』中獲取多樣本。

2. 初始化

對於 MH 算法,由於初始點的不確定性,大部分情況下算法會經過一段長時的 預燒期 才能開始有效的優化過程,即在開始接受第一個數據點之前會拒絕很大一部分數量的數據點。為了避免這種情況,本文對如何初始化狀態鏈的方法進行了詳細的介紹。在清理和初始化每一條狀態鏈時,可以使用真實數據的採樣結果對狀態鏈進行優化。在遍歷了整個狀態鏈之後,如果沒有一個數據被接受,MH-GAN 會從生成樣本中重新開始抽樣,從而確保真實數據中的樣本不被輸出。值得注意的是,MH-GAN 不需要真實的樣本進行初始化,只需要它所對應的判別器分數即可。

3. 校準

實際上,得到完美的 D 是不可能的,但是通過校準步驟可以達到相對完美的程度。另外,完美判別器的假設也不一定就真如它看起來那麼好用。因為判別器僅對生成器和最初的真實數據進行評價,它只需要對來自生成器和真實數據分佈的達到精確判別就可以。在一般的 GAN 訓練中,一般不需要嚴格的要求判別器 D 的值達到一個確定的邊界。但是 MH 算法需要從概率密度比方面對這個值進行良好的校準,從而得到正確的接受比。MH-GAN 使用 10% 的訓練數據作為隨機測試集,使用保序回歸的方法對判別器 D 進行調整。

1D 和 2D 高斯結果

Uber 在論文中使用了一些小例子對 MH-GAN 和 DRS 方法進行了比較,其中真實數據來源於四個單變量的高斯模型的混合結果。通過 pG 的概率密度圖可以看出普通的 GAN 存在的通病,它們的生成結果都缺失了一種模式(如圖 4 所示)。但是,不使用γ校正 DRS 和 MH-GAN 則能良好的還原混合模型,而使用γ進行調整的 DRS 不能還原原始分佈。然而,與使用γ進行調整的 DRS 方法相比,不使用γ的 DRS 方法在第一次接受之前抽樣的數量增加了一個數量級。

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 4:圖中真實數據來自於四個高斯模型組成的 GMM,可以看出生成器的概率密度分佈確實了一個模式。MH-GAN 和不使用γ的 DRS 能夠產生該模式,儘管在第一次接受之前後者需要大量的抽樣數據。 大部分文獻 都喜歡用 5*5 的 2D 高斯模型作為一個簡單的例子進行簡單演示,Uber 也使用了這樣的 2D 模型對基礎 GAN、DRS、MH-GAN 在不同訓練階段下的情況進行了比較,如圖 5 所示。所有的方法都採用了一個 4 層全連接卷積神經網絡,使用線性整流函數(ReLU)作為激活函數,以及一個 100 維的隱層和一個維度為 2 的噪聲向量。從視覺效果上來講,相較於基礎 GAN 的 DRS 取得了明顯的提升,但是它的結果還是更接近基礎 GAN 而不是真實數據。MH-GAN 可以模擬出所有 25 種模式並且從視覺效果上來講更接近於真實數據。定量角度講,MH-GAN 相較於其他方法具有更小的 JS 散度

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 5:上圖是 25 種高斯模型的 2D 分佈情況。相較於基礎 GAN,儘管 DRS 的樣本點更集中於模式周圍,但它缺失的一些模式上看起來與前者很相似,而 MH-GAN 則與真實數據更為相似。下圖展示 MH-GAN 具有更小的 JS 散度。

在 CIFAR-10 和 CelebA 上的結果

這部分內容主要展示了 MH-GAN 在真實數據上的效果,分別測試了選取使用了 梯度懲罰DCGANWGAN 作為基礎 GAN 的結果。在圖 6 的表格中展示了校準後的 MH-GAN 的感知分數( Inception Socre )。

感知分數會完全忽略真實數據而只是用生成的圖像進行評價,它需要將生成圖像傳入在 ImageNet 上預訓練好的 感知分類器 中,感知分數會對輸入圖像屬於某個詳細類的置信度和預測類別的多樣性進行測量。儘管感知分數存在缺陷,但它仍被廣泛用於與其他工作進行比較。

基本上校準後的 MH-GAN 比其他方法都可以取得更好的效果,但是在整個訓練過程中這種優勢並不是一直存在的。對於這種情況的一個解釋是,對於某一輪的迭代,判別器的分數與理想的判別器分數存在巨大差異,從而導致了接受概率缺乏準確性。

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

Uber 提出基於 Metropolis-Hastings 算法的 GAN 改進思想

圖 6:在 CIFAR-10 和 CelebA 上的感知分數,值越高表示效果越好。表格中的數據是第六十次迭代後的結果。

未來工作

MH-GAN 是一種提升 GAN 生成器的簡單方法,該方法使用 Metropolis-Hastings 算法作為一個後處理步驟。在模擬數據和真實數據上 MH-GAN 都表現除了超越基礎 GAN 的效果,與最近提出的 DRS 方法相比 MH-GAN 也更具有優勢。目前該方法僅在較小的數據庫和網絡上進行了驗證,下一步 Uber 計劃將該方法用於更大的數據庫和更先進的網絡。將 MH-GAN 方法擴展到大規模數據庫和 GAN 的途徑是非常簡單粗暴的,因為僅需要額外提供判別器分數和生成器產生的樣本就可以!

此外,使用 MCMC 算法提升 GAN 的思想也可以擴展到其他更高效的算法上,例如漢密爾頓蒙特卡洛方法。如果想獲取關於 MH-GAN 的更多細節和圖表可以閱讀論文: Metropolis-Hastings Generative Adversarial Network ,如果想復現該工作,Uber 提供了該方法基於 Pytorch 的 開源代碼

閱讀英文原文: https://eng.uber.com/mh-gan/

原文 : InfoQ

相關閱讀

免责声明:本文内容来源于InfoQ,已注明原文出处和链接,文章观点不代表立场,如若侵犯到您的权益,或涉不实谣言,敬请向我们提出检举。