在PyTorch中保存经过训练的模型的最佳方法?
问题内容:
我一直在寻找其他方法来在PyTorch中保存经过训练的模型。到目前为止,我发现了两种选择。
- 使用torch.save()保存模型,使用torch.load()加载模型。
- model.state_dict()保存训练的模型,model.load_state_dict()加载保存的模型。
我碰到过这种讨论,其中建议方法2优于方法1。
我的问题是,为什么选择第二种方法呢?仅仅是因为torch.nn模块具有这两个功能,我们被鼓励使用它们吗?
问题答案:
我在他们的github仓库中找到了此页面,我将内容粘贴在这里。
推荐的模型保存方法
序列化和还原模型有两种主要方法。
第一个(推荐)仅保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后再:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
然后再:
the_model = torch.load(PATH)
但是,在这种情况下,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。