NDArray自动求导

NDArray可以很方便的求解导数,比如下面的例子:(代码主要参考自https://zh.gluon.ai/chapter_crashcourse/autograd.html

 用代码实现如下:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3 x = nd.array([[1,2],[3,4]])
 4 print(x)
 5 x.attach_grad() #附加导数存放的空间
 6 with ag.record():
 7     y = 2*x**2
 8 y.backward() #求导
 9 z = x.grad #将导数结果(也是一个矩阵)赋值给z
10 print(z) #打印结果
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

[[  4.   8.]
 [ 12.  16.]]
<NDArray 2x2 @cpu(0)>

 

对控制流求导

NDArray还能对诸如if的控制分支进行求导,比如下面这段代码:

1 def f(a):
2     if nd.sum(a).asscalar()<15: #如果矩阵a的元数和<15
3         b = a*2 #则所有元素*2
4     else:
5         b = a
6     return b

数学公式等价于:

这样就转换成本文最开头示例一样,变成单一函数求导,显然导数值就是x前的常数项,验证一下:

import mxnet.ndarray as nd
import mxnet.autograd as ag

def f(a):
    if nd.sum(a).asscalar()<15: #如果矩阵a的元数和<15
        b = a*2 #则所有元素平方
    else:
        b = a
    return b

#注:1+2+3+4<15,所以进入b=a*2的分支
x = nd.array([[1,2],[3,4]])
print("x1=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y1=")
print(y)
y.backward() #dy/dx = y/x 即:2
print("x1.grad=")
print(x.grad)

x = x*2
print("x2=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y2=")
print(y)
y.backward()
print("x2.grad=")
print(x.grad)
x1=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>
y1=
[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x1.grad=
[[ 2.  2.]
 [ 2.  2.]]
<NDArray 2x2 @cpu(0)>
x2=
[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
y2=
[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x2.grad=
[[ 1.  1.]
 [ 1.  1.]]
<NDArray 2x2 @cpu(0)>

 

头梯度

原文上讲得很含糊,其实所谓头梯度,就是一个求导结果前的乘法系数,见下面代码:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3
 4 x = nd.array([[1,2],[3,4]])
 5 print("x=")
 6 print(x)
 7
 8 x.attach_grad()
 9 with ag.record():
10     y = 2*x*x
11
12 head = nd.array([[10, 1.], [.1, .01]]) #所谓的"头梯度"
13 print("head=")
14 print(head)
15 y.backward(head_gradient) #用头梯度求导
16
17 print("x.grad=")
18 print(x.grad) #打印结果
x=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>
head=
[[ 10.     1.  ]
 [  0.1    0.01]]
<NDArray 2x2 @cpu(0)>
x.grad=
[[ 40.           8.        ]
 [  1.20000005   0.16      ]]
<NDArray 2x2 @cpu(0)>

对比本文最开头的求导结果,上面的代码仅仅多了一个head矩阵,最终的结果,其实就是在常规求导结果的基础上,再乘上head矩阵(指:数乘而非叉乘)

 

链式法则

先复习下数学

注:最后一行中所有变量x,y,z都是向量(即:矩形),为了不让公式看上去很凌乱,就统一省掉了变量上的箭头。NDArray对复合函数求导时,已经自动应用了链式法则,见下面的示例代码:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3
 4 x = nd.array([[1,2],[3,4]])
 5 print("x=")
 6 print(x)
 7
 8 x.attach_grad()
 9 with ag.record():
10     y = x**2
11     z = y**2 + y
12
13 z.backward()
14
15 print("x.grad=")
16 print(x.grad) #打印结果
17
18 print("w=")
19 w = 4*x**3 + 2*x
20 print(w) # 验证结果
x=
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>
x.grad=
[[   6.   36.]
 [ 114.  264.]]
<NDArray 2x2 @cpu(0)>
w=
[[   6.   36.]
 [ 114.  264.]]
<NDArray 2x2 @cpu(0)>

 

时间: 2024-10-31 15:07:29

NDArray自动求导的相关文章

mfc-MFC中如何使用求导法判断周期数据中的拐点和鞍点

问题描述 MFC中如何使用求导法判断周期数据中的拐点和鞍点 MFC中如何使用求导法判断周期数据中的拐点和鞍点,分别提取它们到新的数据文件 解决方案 http://zhidao.baidu.com/link?url=hFcevJMBxUipTvIYlyYLYpCQsQkLqhwInXMkJ0Qtt65fs1UPgtAasbpQIxQzk4Tz9ZMBWUbk-ao5PGrGV3w6fAoRv73ZYSgLAh0i3IHurTW

mfc-MFC函数求幂之后已经有了原浮点数和幂,是不是就可以求导?

问题描述 MFC函数求幂之后已经有了原浮点数和幂,是不是就可以求导? MFC函数求幂之后已经有了原浮点数和幂,是不是就可以求导?求导的函数公式怎么计算?看了一些资料但是不是很明晰 解决方案 http://www.bianceng.cn/Programming/cplus/201411/46664.htm

c语言- 用C语言实现一元多项式求导

问题描述 用C语言实现一元多项式求导 时间限制 400 ms 内存限制 65536 kB 代码长度限制 8000 B 判题程序 Standard 设计函数求一元多项式的导数. 输入格式:以指数递降方式输入多项式非零项系数和指数(绝对值均为不超过1000的整数).数字间以空格分隔. 输出格式:以与输入相同的格式输出导数多项式非零项的系数和指数.数字间以空格分隔,但结尾不能有多余空格.注意"零多项式"的指数和系数都是0,但是表示为"0 0". 输入样例: 3 4 -5

BP算法双向传,链式求导最缠绵(深度学习入门系列之八)

更多深度文章,请关注:https://yq.aliyun.com/cloud 系列文章: 一入侯门"深"似海,深度学习深几许(深度学习入门系列之一)人工"碳"索意犹尽,智能"硅"来未可知(深度学习入门系列之二)神经网络不胜语,M-P模型似可寻(深度学习入门系列之三)"机器学习"三重门,"中庸之道"趋若人(深度学习入门系列之四)Hello World感知机,懂你我心才安息 (深度学习入门系列之五)损失函数减肥

update-Update 数据表时自动求如何写SQL语句呢(使用SQL Sever)?(设计触发器或存储过程吧)

问题描述 Update 数据表时自动求如何写SQL语句呢(使用SQL Sever)?(设计触发器或存储过程吧) 我有一个学生考试信息表: 考号,姓名,语文成绩,数学成绩,英语成绩,文综成绩,总成绩 (PS:默认各科成绩,总成绩都为 0)在老师登分时只会登入各个科目的成绩,我使用的是SQL Sever数据库,当老师登入各科成绩时(使用Update),如何触发自动求和?_谢谢!_ 解决方案 create trigger trig_学生考试信息表 on 学生考试信息表 for insert as be

link能不能自动求出一段话中最长的重复的两句话?

问题描述 link能不能自动求出一段话中最长的重复的两句话? link能不能自动求出一段话中最长的重复的两句话? 解决方案 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; namespace ConsoleApplication2 { class Program { static void Main(string[

mfc-一个MFC求幂指数函数和求导结合的问题

问题描述 一个MFC求幂指数函数和求导结合的问题 如果C++ 6.0 MFC已经求出了幂指数函数,请问如何进一步求指定浮点数的导数 解决方案 求导和求指数函数是两回事,参考:http://wenku.baidu.com/link?url=IA5Py3jd-tpw-ZMuvtSrW8vlhKr32BTf7DeGhsYUNuMnqWhMw5RPj14r5bcng7bz0sMXuaWQTCvxAeE8t765OsGeuFuB5WhH9Nhyyw9JU3K

Excel自动求平均值的函数公式

在制作表格的过程中,我们可能会用Excel来对数据进行各种运算,如:求和.求差.求积等公式,来完成我们的运算.在前面几课中我们已经基本的讲解了各种运算的函数公式,本篇再来说下在Excel表格中如何求平均值.我们在制作一份成绩表排名的时候,知道了各科成绩,需要求出成绩的平均值,我们该如何来完成呢?下面就看看Word联盟为大家演示吧! 首先,这里是一份成绩表,上面有各门功课的成绩,我们要求出平均分数. ①将光标定位到"平均分"下面一个单元格中,然后点击"插入函数"按钮,

Excel自动求平均值函数公式

  在制作表格的过程中,我们可能会用Excel来对数据进行各种运算,如:求和.求差.求积等公式,来完成我们的运算.在前面几课中我们已经基本的讲解了各种运算的函数公式,本篇再来说下在Excel表格中如何求平均值.我们在制作一份成绩表排名的时候,知道了各科成绩,需要求出成绩的平均值,我们该如何来完成呢? 首先,这里是一份成绩表,上面有各门功课的成绩,我们要求出平均分数. ①将光标定位到"平均分"下面一个单元格中,然后点击"插入函数"按钮,如下图红色区域便是; ②在弹出的