导语:在使用 PyTorch 搭建网络的过程中,经常遇到一些 python 的基础知识,不了解的话就会卡住半天想不出来,在此先总结一下这段时间遇到的问题。

Python的重载

问题来源:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=(2, 2))
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5))

self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.out = nn.Linear(in_features=84, out_features=10)

def forward(self, t):
t = F.relu(self.conv1(t)) # 激活操作
t = F.max_pool2d(t, kernel_size=2, stride=2) # 池化操作
t = F.relu(self.conv2(t))
t = F.max_pool2d(t, kernel_size=2, stride=2)
t = t.view(-1, num_flat_features(t))
t = F.relu(self.fc1(t))
t = F.relu(self.fc2(t))
t = self.out(t)

return t

在用 pytorch 搭建这个网络的时候,一个问题引起了我的注意。作为最开始 java 的程序员学习编程的人来说,*面向对象(OOP)*的概念应该是很基础的一个问题了,但是,我却在这里犯了愁。请看这一行:

1
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=(2, 2))

按住ctrl,点击Conv2d, Pycharm 自动将我带到了定义它的地方。这是一个类,如下

其继承关系为

既然是一个类,那么self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=(2, 2))这句话的意思就是将该类实例化,self.conv1这个变量指向实例化之后的对象,这没什么问题,但是,为什么在后面的forward函数中会出现self.conv1(t)这么一个奇怪的语法呢?这是对变量进行传参???这个问题很奇怪,以至于我都不知道该怎么百度搜。

后来,我在Module中找到了这么一个语句:

1
__call__ : Callable[..., Any] = _call_impl

百度 python 的__call__用法之后,一切都解决了。原来这是个类似于 python 重载的东西,它使得类实例对象可以像调用普通函数那样,以 “对象名()” 的形式使用。

可参考python中的 call()

于是我写了个测试去验证

1
2
3
4
5
6
7
8
class FunctionLike(object):
def __call__(self, a):
print("I got called with", a)


fn = FunctionLike()

fn(10)

输出结果为

1
I got called with 10

疑问解决。

PyTorch 中,nn 与 nn.functional 有什么区别? - 有糖吃可好的回答 - 知乎

nn.functional.xxx是函数接口,而nn.Xxxnn.functional.xxx的类封装,并且**nn.Xxx都继承于一个共同祖先nn.Module。**

  • nn.Xxx除了具有nn.functional.xxx功能之外,内部附带了nn.Module相关的属性和方法,例如train(), eval(), load_state_dict, state_dict 等。

  • nn.Xxx继承于nn.Module能够很好的与nn.Sequential结合使用, 而nn.functional.xxx无法与nn.Sequential结合使用。

  • nn.Xxx不需要你自己定义和管理weight;而nn.functional.xxx需要你自己定义weight,每次调用的时候都需要手动传入weight, 不利于代码复用。

阅读源码发现 nn.Xxx里面的是继承自nn.module 初始化为实例化一个类,如果含参数,则会帮你初始化好参数。而nn.functional.xxx里面则是相当于直接一个函数句柄给你,如果需要参数,则需要你自己输入参数,且并不会“记住”这个参数。

使用原则:

如果要涉及到参数计算的,那用nn.Xxx 里的;如果不需要涉及更新参数,只是一次性计算,那用nn.functional.xxx里面的。

PyTorch官方推荐:具有学习参数的(例如,conv2d, linear, batch_norm)采用nn.Xxx方式,没有学习参数的(例如,maxpool, loss func, activation func)等根据个人选择使用nn.functional.xxx或者nn.Xxx方式。