天顺娱乐-天顺平台注册站
 
 
pytorch源码阅读(二)optimizer原理
来源:网络 时间:2024-03-04 12:40

pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD、Adam、LBFGS以及RMSProp等。使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer【1】定义了各个子类(即SGD等)的核心行为,下面是optimizer类注释:

class Optimizer(object):
    r"""Base class for all optimizers.
    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
    """

其中首句“所有优化器的基类” 表明所有的优化器都必须继承optimizer类,下面来分析optimizer类的的各个实例函数。

1、初始化__init__()

def __init__(self, params, defaults):
    torch._C._log_api_usage_once("python.optimizer")
    self.defaults = defaults
    self.state = defaultdict(dict)
    self.param_groups = []
    param_groups = list(params)
    # 省略类型检查
    for param_group in param_groups:
        self.add_param_group(param_group)

优化器需要保存学习率等参数的值,所以optimizer类需要用实例属性来存储这些参数,也就是__init__()中的self.param_groups,下面的代码通过一个全连接网络来测试优化器的param_groups包含哪些参数:

net = nn.Linear(2, 2)
# 权重矩阵初始化为1
nn.init.constant_(net.weight, val=100)
nn.init.constant_(net.bias, val=20)
optimizer = optim.SGD(net.parameters(), lr=0.01)
print(optimizer.param_groups)

得到:

[{'params': [Parameter containing:
tensor([[ 100.,  100.],
        [ 100.,  100.]]), Parameter containing:
tensor([20,, 20])], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

其中2x2的矩阵是net的权重矩阵,1x2为偏置矩阵,其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。


2、优化器状态state_dict()

def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict` """
    # Save ids instead of Tensors
    def pack_group(group):
        # 对"params"和其它的键采用不同规则
        packed = {k: v for k, v in group.items() if k != 'params'}
        # 这里并没有保存参数的值,而是保存参数的id
        packed['params'] = [id(p) for p in group['params']]
        return packed
    # 对self.param_groups进行遍历
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    # 返回状态和参数组,其中参数组才是优化器的参数
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

查看上一节定义的optimizer的state_dict():

print(optimizer.state_dict()["param_groups"])

可以到优化器的完整参数如下:

[{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 
'nesterov': False, 'params': [2149749904224, 2149749906312]}]


3、优化器参数加载load_state_dict()

上一节中的state_dict()负责提取优化器的参数,可以保存到本地用于下次训练恢复使用,对应的必然有load_state_dict()用于优化器参数的加载,其源码如下:

def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.
    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # deepcopy, to be consistent with module API
    # 应该是防止函数中对输入的state_dict进行改动
    # 因为字典是可变数据类型
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = self.param_groups
    saved_groups = state_dict['param_groups']
    # 参数的长度检测,保证输入的state_dict和优化器的参数数目一致
    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")
    # 用输入的state_dict更新当前state_dict的状态                     
      id_map = {old_id: p for old_id, p in
              zip(chain(*(g['params'] for g in saved_groups)),
                  chain(*(g['params'] for g in groups)))}
    # 省略具体实现              

为了测试state_dict()和load_state_dict(),可以首先存储一个学习率为100的优化器的参数到本地:

optimizer_old = optim.SGD(net.parameters(), lr=100) 
torch.save(optimizer_old.state_dict(), "optim_old.npy")

现在这个优化器的参数已经存储到本地,然后将这个优化器参数重新加载给一个新的学习率为0.01优化器:

optimizer_new = optim.SGD(net.parameters(), lr=0.01)
old_state = torch.load("optim_old.npy")
# 将之前定义的优化器参数给新的优化器
optimizer_new.load_state_dict(old_state)
print(optimizer_new.state_dict()["param_groups"])

得到new优化器的学习率不是0.01,而是old优化器的学习率100:

[{'lr': 100, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2122843345256, 2122843345112]}]


4、梯度清空zero_grad()

在网络优化过程中optimizer.zero_grad()函数需要被显式调用,负责清空其关联网络的参数梯度值,其源码如下:

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    # 获取每一组参数
    for group in self.param_groups:
        # 遍历当前参数组所有的params
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

这个遍历过程就是获取optimizer的param_groups属性的字典,之中的["params"],之中的所有参数,通过遍历设定每个参数的梯度值为0。

5、单步更新step()

def step(self, closure):
    r"""Performs a single optimization step (parameter update).
    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    """
    raise NotImplementedError

优化器的step()函数负责更新参数值,但是其具体实现对于不同的优化算法是不同的,所以optimizer类只是定义了这种行为,但是并没有给出具体实现。

6、总结

优化算法部分的代码并不多,但是不同的优化算法涉及的概念较多,看懂各种算法的实现需要很强的数学功底。optimizer类定义了各种优化算法的公共行为与抽象方法,是典型的面向对象的继承思想。


参考:

【1】github.com/pytorch/pyto

 

联系我们

400-123-4567 仅限中国 9:00-20:00
微信二维码
Copyright © 2002-2022 天顺娱乐-天顺平台注册站 版权所有    粤IP********    

平台注册入口