系统城装机大师 - 固镇县祥瑞电脑科技销售部宣传站!

当前位置:首页 > 脚本中心 > python > 详细页面

PyTorch中的Variable变量详解

时间:2020-01-07来源:系统城作者:电脑系统城

一、了解Variable

顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。

具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。

【tensor 是一个多维矩阵】

用一个例子说明,Variable的定义:


 
  1. import torch
  2. from torch.autograd import Variable # torch 中 Variable 模块
  3. tensor = torch.FloatTensor([[1,2],[3,4]])
  4. # 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
  5. variable = Variable(tensor, requires_grad=True)
  6.  
  7. print(tensor)
  8. """
  9. 1 2
  10. 3 4
  11. [torch.FloatTensor of size 2x2]
  12. """
  13.  
  14. print(variable)
  15. """
  16. Variable containing:
  17. 1 2
  18. 3 4
  19. [torch.FloatTensor of size 2x2]
  20. """

注:tensor不能反向传播,variable可以反向传播。

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。


 
  1. v_out.backward() # 模拟 v_out 的误差反向传递
  2.  
  3. print(variable.grad) # 初始 Variable 的梯度
  4. '''
  5. 0.5000 1.0000
  6. 1.5000 2.0000
  7. '''

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。


 
  1. print(variable) # Variable 形式
  2. """
  3. Variable containing:
  4. 1 2
  5. 3 4
  6. [torch.FloatTensor of size 2x2]
  7. """
  8.  
  9. print(variable.data) # 将variable形式转为tensor 形式
  10. """
  11. 1 2
  12. 3 4
  13. [torch.FloatTensor of size 2x2]
  14. """
  15.  
  16. print(variable.data.numpy()) # numpy 形式
  17. """
  18. [[ 1. 2.]
  19. [ 3. 4.]]
  20. """

扩展

在PyTorch中计算图的特点总结如下:

autograd根据用户对Variable的操作来构建其计算图。

1、requires_grad

variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。

2、volatile

variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。

3、retain_graph

多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。

4、backward()

反向传播,求解Variable的梯度。放在中间缓存中。

以上这篇PyTorch中的Variable变量详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

分享到:

相关信息

系统教程栏目

栏目热门教程

人气教程排行

站长推荐

热门系统下载