导语:在使用 PyTorch 搭建网络的过程中,经常遇到一些 python 的基础知识,不了解的话就会卡住半天想不出来,在此先总结一下这段时间遇到的问题。
Python的重载
问题来源:
1 | class LeNet(nn.Module): |
在用 pytorch 搭建这个网络的时候,一个问题引起了我的注意。作为最开始 java 的程序员学习编程的人来说,*面向对象(OOP)*的概念应该是很基础的一个问题了,但是,我却在这里犯了愁。请看这一行:
1 | nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=(2, 2)) |
按住ctrl
,点击Conv2d
, Pycharm 自动将我带到了定义它的地方。这是一个类,如下
其继承关系为
既然是一个类,那么self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=(2, 2))
这句话的意思就是将该类实例化,self.conv1
这个变量指向实例化之后的对象,这没什么问题,但是,为什么在后面的forward
函数中会出现self.conv1(t)
这么一个奇怪的语法呢?这是对变量进行传参???这个问题很奇怪,以至于我都不知道该怎么百度搜。
后来,我在Module
中找到了这么一个语句:
1 | __call__ : Callable[..., Any] = _call_impl |
百度 python 的__call__
用法之后,一切都解决了。原来这是个类似于 python 重载的东西,它使得类实例对象可以像调用普通函数那样,以 “对象名()” 的形式使用。
于是我写了个测试去验证
1 | class FunctionLike(object): |
输出结果为
1 | I got called with 10 |
疑问解决。
PyTorch 中,nn 与 nn.functional 有什么区别? - 有糖吃可好的回答 - 知乎
nn.functional.xxx
是函数接口,而nn.Xxx
是nn.functional.xxx
的类封装,并且**nn.Xxx
都继承于一个共同祖先nn.Module
。**
-
nn.Xxx
除了具有nn.functional.xxx
功能之外,内部附带了nn.Module
相关的属性和方法,例如train(), eval(), load_state_dict, state_dict
等。 -
nn.Xxx
继承于nn.Module
, 能够很好的与nn.Sequential
结合使用, 而nn.functional.xxx
无法与nn.Sequential
结合使用。 -
nn.Xxx
不需要你自己定义和管理weight;而nn.functional.xxx
需要你自己定义weight,每次调用的时候都需要手动传入weight, 不利于代码复用。
阅读源码发现 nn.Xxx
里面的是继承自nn.module
初始化为实例化一个类,如果含参数,则会帮你初始化好参数。而nn.functional.xxx
里面则是相当于直接一个函数句柄给你,如果需要参数,则需要你自己输入参数,且并不会“记住”这个参数。
使用原则:
如果要涉及到参数计算的,那用nn.Xxx
里的;如果不需要涉及更新参数,只是一次性计算,那用nn.functional.xxx
里面的。
PyTorch
官方推荐:具有学习参数的(例如,conv2d
, linear
, batch_norm
)采用nn.Xxx
方式,没有学习参数的(例如,maxpool
, loss func
, activation func
)等根据个人选择使用nn.functional.xxx
或者nn.Xxx
方式。