在PyTorch中保存经过训练的模型的最佳方法?


问题内容

我一直在寻找其他方法来在PyTorch中保存经过训练的模型。到目前为止,我发现了两种选择。

  1. 使用torch.save()保存模型,使用torch.load()加载模型。
  2. model.state_dict()保存训练的模型,model.load_state_dict()加载保存的模型。

我碰到过这种讨论,其中建议方法2优于方法1。

我的问题是,为什么选择第二种方法呢?仅仅是因为torch.nn模块具有这两个功能,我们被鼓励使用它们吗?


问题答案:

我在他们的github仓库中找到了此页面,我将内容粘贴在这里。


推荐的模型保存方法

序列化和还原模型有两种主要方法。

第一个(推荐)仅保存和加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后再:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后再:

the_model = torch.load(PATH)

但是,在这种情况下,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。