陳凱
用梯度下降的方法來實現(xiàn)線性回歸,是一種很經(jīng)典的機器學(xué)習(xí)算法。然而,在基礎(chǔ)教育階段,由于受學(xué)生自身數(shù)學(xué)水平和信息技術(shù)水平的限制,他們對這種算法的基本原理以及實現(xiàn)回歸過程的程序代碼的學(xué)習(xí)和理解,難度還是比較大的。筆者曾經(jīng)瀏覽過不少線性回歸方面的資料,發(fā)現(xiàn)學(xué)習(xí)路徑頗為陡峭。本篇文章試著從一個小游戲入手,一邊玩一邊想,逐層鋪墊,緩慢進(jìn)階,希望能讓學(xué)習(xí)者避免迷失在諸多概念名詞和繁雜程序代碼之中,真正體驗到線性回歸算法的精髓。
● 冰糖葫蘆手工串
先來看游戲的玩法。有幾個糖葫蘆散落在桌面上,要求不改變糖葫蘆的位置,用一根竹簽將它們串起來,游戲一開始,竹簽只是平放在桌面邊緣,并沒有串起任何一個糖葫蘆,如圖1所示。
所謂的桌面,實際上是一個坐標(biāo)軸,X軸和Y軸的范圍都是0到10;所謂的糖葫蘆,是四個面積比較大的點。實現(xiàn)串糖葫蘆游戲的Python代碼非常簡單,如圖2所示。
坐標(biāo)軸上的這四個點的markersize參數(shù)是40,所以看上去就相當(dāng)大,可以想見,這個參數(shù)越大,糖葫蘆也越大,游戲也越簡單。糖葫蘆所在的坐標(biāo)直接寫在了plot函數(shù)的參數(shù)里,為代碼簡單清晰起見,這里并沒有引入隨機數(shù)。
所謂的竹簽,是通過y=t1+t2*x這個一元一次方程產(chǎn)生的直線,其中t1是方程的常數(shù),t2是一次項系數(shù)。如果輸入t1為0.9,輸入t2為0.6,那么竹簽就串到了兩個糖葫蘆,如上頁圖3所示。
試過幾次就可以發(fā)現(xiàn),t1和t2兩個數(shù)字的作用大不相同:前者決定了竹簽的位置,“靠上還是靠下”與“靠左還是靠右”其實是一回事;后者決定了竹簽的傾斜度。如果想讓竹簽往上擺一些,則t1要增加;如果想讓竹簽擺得平一些,則t2要減少。不妨再試一下,輸入t1為3,輸入t2為0.2,竹簽成功地串起了三個糖葫蘆,如上頁圖4所示。
多次嘗試之后,就能體會到t1和t2兩個數(shù)字的變化與最終直線形態(tài)之間的微妙關(guān)系。
● 竹簽離得有多遠(yuǎn)
剛才是用人腦來判斷竹簽的位置t1和傾角t2,那么,怎么讓機器判斷位置和傾角呢?方法就是“看了再試,試了再看”。
例如,一開始,竹簽是平躺著的,竹簽到每個點的縱向的距離(為簡單起見,這里暫不考慮橫向的距離)是可以計算出來的,如圖5所示,其中左圖顯然離開理想結(jié)果還很遙遠(yuǎn),右圖已經(jīng)比較接近目標(biāo)了。
為了能夠計算出竹簽到糖葫蘆的距離,以評估竹簽與理想目標(biāo)之間的差距,可以將代碼稍微修改一下。由于糖葫蘆可能在竹簽上面,也可能在竹簽下面,在計算距離時,數(shù)值可能是正,也可能是負(fù),所以縱向距離統(tǒng)一進(jìn)行二次方的運算(其實取絕對值也是一樣的)。又因為總共有四個糖葫蘆,所以要除以4,得到一個縱向距離的平均數(shù)。修改后的程序代碼如圖6所示。
e = (e1+e2+e3+e4)/4/2這段代碼,就是將糖葫蘆中心點和竹簽的縱向距離的平均值賦值給變量e,其中唯一難理解的是,為什么除以4取得平均值后又要除以2,其實這是為了讓后續(xù)的求導(dǎo)公式更加便捷,若是學(xué)習(xí)者學(xué)習(xí)過隱藏在線性回歸算法之后的數(shù)學(xué)原理,那么就能更清楚地知道這里為何要除以2。實際上,因為變量e的值只是用于在運算過程中指示竹簽位置和傾角離開理想值的差距,所以是否除以2,其實是無所謂的。假如說任務(wù)是要讓竹簽盡可能靠近糖葫蘆的中心,那么就需要觀察e值是否能收斂于某個值,然而,本文任務(wù)只要求把糖葫蘆串起來即可,因此后續(xù)的程序代碼中并不需要用到e變量。
上述代碼中,t1和t2的初始值都是0,運行后,發(fā)現(xiàn)得到的e的值是9.15227,這個數(shù)字顯然太大了,計算結(jié)果和人眼的直觀感受是符合的,所以需要改變t1和t2的值使得竹簽距離糖葫蘆更近一些。
● 竹簽需要重新擺
為了讓竹簽有可能串起更多糖葫蘆,就要重新調(diào)整竹簽的位置和傾角,其實就是更改t1和t2的值。但計算機怎么知道應(yīng)該如何調(diào)整呢?
首先來看竹簽的位置,竹簽究竟是往上放還是往下放?應(yīng)該需要調(diào)整多少距離?想象一下,如果大部分糖葫蘆都在竹簽上方,那么就要往上挪,反之就是往下挪,離開越遠(yuǎn),挪的距離就越多,程序中涉及的公式如上頁圖7所示。
這個公式是用微積分的方法推演出來的,然而就算是直觀上也是可以理解的,當(dāng)竹簽平躺著的時候,得到的d1的值是4.18,這表示竹簽要往上方移動4.18個單位,如果d1的數(shù)值是負(fù)數(shù),則表示要將竹簽往下方移動。
傾角的調(diào)整要復(fù)雜一些,因為這和每個糖葫蘆的橫向位置有關(guān),涉及的公式如上頁圖8所示。
這同樣是用微積分求導(dǎo)的方法推演出來的,如果不想深入了解相關(guān)數(shù)學(xué)推導(dǎo)過程,那只需要驗證公式的合理性就可以了。當(dāng)竹簽平躺著的時候,得到的值是23.65975。如果是正數(shù),表示要把竹簽傾角增大,如果是負(fù)數(shù),則表示將傾角變小。不過即便是直觀看圖9,也是能感受到公式的意義的,因為離開坐標(biāo)軸原點越遠(yuǎn),點的坐標(biāo)值本身的權(quán)重也就越大。
然而,按公式計算后,d1和d2這兩個數(shù)值都太大了,實際操作時矯枉過正,于是還要對這兩個值再乘上步長系數(shù)0.05,此數(shù)可大可小,要根據(jù)實際運行情況來調(diào)整。系數(shù)太小,則需要很多次調(diào)整才能達(dá)成目標(biāo);系數(shù)太大,則會在理想結(jié)果周圍來回跳躍。
最后,將d1和d2的修正值疊加到原先的t1和t2值上,將循環(huán)次數(shù)增加到200次,每一次循環(huán)中,做的都是同樣的事:用當(dāng)前的t1和t2計算出直線擺放姿態(tài),然后將運行效果比對預(yù)期效果后得到調(diào)整值d1和d2,再用d1和d2修正t1和t2值,這種方法稱為迭代法。圖10、圖11是完整的代碼和運行結(jié)果。
這個例子本身形象直觀,而代碼也比較簡單,若去掉繪圖相關(guān)的代碼,核心代碼僅有十多行,其中隱藏了超出學(xué)生當(dāng)前水平的知識點,尤其是數(shù)學(xué)微積分和線性代數(shù)有關(guān)的知識技能,但同時又留下了繼續(xù)深入學(xué)習(xí)的路徑指引,為學(xué)習(xí)者提供了真切的機器學(xué)習(xí)算法的實踐體驗。