Pytorch: 功能请求:load_state_dict 应该采用文件名

创建于 2017-05-31  ·  3评论  ·  资料来源: pytorch/pytorch

在高内存压力的情况下,以下是常见的情况:

  1. 创建模型
  2. 从检查点文件中读取 state_dict(在 GPU 上加载)
  3. model.load_state_dict(s)

由于内存压力,一个常见的解决方法是首先执行以下操作:

s = torch.load('my_file.pt', map_location=lambda storage, loc: storage)

然后将s加载到model中。

这是我们应该能够避免的非常常见的场景,并且这种场景可能存在一些陷阱:在部分 GPU 部分 CPU 模型上会发生什么,在多 GPU 模型上会发生什么......

如果 load_state_dict 直接取一个文件名,它可以删除它现有的参数存储并即时将它们设置为新的,因此不需要额外的内存。

feature nn triaged

最有用的评论

如果load_state_dict采用文件名,我们也应该允许map_location参数。 我的一个常见情况是在集群机器上保存一个检查点,然后将其加载到我的 macbook 上(因此需要将参数加载到 CPU 上)

所有3条评论

这同样适用于优化器 state_dicts。 对于像 Adagrad 这样的优化器,检查点很大,我们可以有相同的内存压力情况。 优化器甚至没有.cuda() ,所以我们首先必须手动将 state_dict 加载到 CPU 上,然后手动将部分复制到 GPU。

我今天在帮助@aszlam时遇到了这个问题。

如果load_state_dict采用文件名,我们也应该允许map_location参数。 我的一个常见情况是在集群机器上保存一个检查点,然后将其加载到我的 macbook 上(因此需要将参数加载到 CPU 上)

我和@szagoruyko是序列化模型的 HDF5 格式的粉丝,如果它可以很好地配合这个提议的话

此页面是否有帮助?
0 / 5 - 0 等级

相关问题

SeparateReality picture SeparateReality  ·  3评论

mishraswapnil picture mishraswapnil  ·  3评论

szagoruyko picture szagoruyko  ·  3评论

miguelvr picture miguelvr  ·  3评论

Coderx7 picture Coderx7  ·  3评论