这一节我们要用梯度下降来对逻辑回归建模,用数学的话来说就是最小化逻辑回归Cost Function。这可以说是对之前几节内容的一次综合应用。在讲之前你可以先复习一下链式求导法则,然后我们再列一下要用到的一些求导公式:
对于y=x^n 则有 \frac{\mathrm{d} y}{\mathrm{d} x}=nx^{n-1}
对于y=e^x 则有 \frac{\mathrm{d} y}{\mathrm{d} x}=e^x
对于y=\ln{x} 则有 \frac{\mathrm{d} y}{\mathrm{d} x}=\frac{1}{x}
以及一些求导法则

以及推论:

那好,我们先看一下之前对逻辑回归的一些公式定义:

假设函数:

\hat{y}=sigmoid(\mathbf{w}^\mathrm{T}x+b)

Loss Function:

L(\hat{y},y)=-(y\log{\hat{y}}+(1-y)\log(1-\hat{y}))

Cost Function:

J(w,b)=\frac{1}{m}\sum_{i=1}^{m}L({\hat{y}}^{(i)},{y}^{(i)})

之前说过逻辑回归的lable可以有多个,则对应的参数w也是一个向量,有多个值。之前我们为了讲解方便都把w认为是一个值,这次我们假设lable有两个,x1,x2。 则对应的w也有两个值 w1,w2。

首先我们用之前学过的知识来画一下Loss Function的计算图。

现在我们想要求解的是Loss Function对w1,w2,b的偏导数,这里我们用后向传播的方法来逐步求解。第一步我们来求解\frac{\mathrm{d} L}{\mathrm{d} {\hat{y}}}。根据之前列出的求导公式和规则,可以得到:
\frac{\mathrm{d} L}{\mathrm{d} {\hat{y}}}=-\frac{y}{\hat{y}}+\frac{1-y}{1-\hat{y}}
然后我们来求解\frac{\mathrm{d} \hat{y}}{\mathrm{d} {z}}:
\frac{\mathrm{d} \hat{y}}{\mathrm{d} {z}}=\frac{e^{-z}}{{(1+e^{-z})}^2}=\hat{y}(1-\hat{y})
接着求解\frac{\mathrm{d} z}{\mathrm{d} {w1}},\frac{\mathrm{d} z}{\mathrm{d} {w2}},\frac{\mathrm{d} z}{\mathrm{d} {b}}:
\frac{\mathrm{d} z}{\mathrm{d} {w1}}=x1
\frac{\mathrm{d} z}{\mathrm{d} {w2}}=x2
\frac{\mathrm{d} z}{\mathrm{d} {b}}=1
分步求解完成后,我们可以应用链式求导法则来求解\frac{\mathrm{d} L}{\mathrm{d} {w1}},\frac{\mathrm{d} L}{\mathrm{d} {w2}},\frac{\mathrm{d} L}{\mathrm{d} {b}}
最终我们得到(w1,w2,b)的梯度向量:
\begin{pmatrix} x1(\hat{y}-y),x2(\hat{y}-y),(\hat{y}-y) \end{pmatrix}

上边都是对Loss Function的计算,它是针对一条训练记录的,我们需要求解的是针对整个训练集的Cost function。目标是最小化Cost Funcion。不过有了上边的结果很简单,因为根据求导规则的推论3,其实就是对每条记录的Loss Function进行求导并累加,最后除以m。

有了这些我们就可以开始编程来实现了,步骤大概是这样:


w1=random(),w2=random(),b=random();// 给我们假设方程随机赋值
dw1,dw2,db=0; //梯度初始化
J=0;J_last=0;//CostFunction 初始化
rate=0.2; //步长系数
do{
    J_last=J;
    J=0;
    w1-=rate*dw1/trainingSet.size()
    w2-=rate*dw2/trainingSet.size()
    b-=rate*db/trainingSet.size()
    foreach item in trainingSet{
        J+=LossFunction(item)
        dw1+=item.x1 * (y_hat-item.y)
        dw2+=item.x2 * (y_hat-item.y)
        db+=y_hat-item.y
    }
while(J<=J_last)

发表评论

电子邮件地址不会被公开。 必填项已用*标注

%d 博主赞过: