195 字
1 分钟
pytorch训练模型的保存和加载方式
方法一
TIP保存模型结构以及权重参数
保存
torch.save(net,'name.pth')#name和pth均可自定义加载
TIP加载时还是要带上模型的结构
#一个小陷阱,不带上原网络模型会报错class Net (nn.Module): def __init__(self): super().__init__( self.f1 = nn.Linear(28 * 28,64) self.f2 = nn.Linear(64,64) self.f3 = nn.Linear(64,64) self.f4 = nn.Linear(64,10)
def forward(self,x): x = relu(self.f1(x)) x = relu(self.f2(x)) x = relu(self.f3(x)) x = log_softmax(self.f4(x), dim=1) return x后再接
net=torch.load('name.pth')方法二
TIP仅保存模型的权重参数(官方推荐)
保存
torch.save(net.state_dict(),'name.pth')#net为你的模型名加载
#同样需要先定义模型结构class Net (nn.Module): def __init__(self): super().__init__( self.f1 = nn.Linear(28 * 28,64) self.f2 = nn.Linear(64,64) self.f3 = nn.Linear(64,64) self.f4 = nn.Linear(64,10)
def forward(self,x): x = relu(self.f1(x)) x = relu(self.f2(x)) x = relu(self.f3(x)) x = log_softmax(self.f4(x), dim=1) return x
net=Net()#实例化模型net.load_state_dict(torch.load('name.pth'))NOTE根据保存方式,选择对应的加载方法
pytorch训练模型的保存和加载方式
https://mizuki.mysqil.com/posts/pytorch_model_saveのload/ 部分信息可能已经过时









