torch.load()如何加载模型并详细解析map_location参数?
- 内容介绍
- 文章标签
- 相关推荐
本文共计931个文字,预计阅读时间需要4分钟。
目录参考torch.load()模型保存模型加载中的map_location参数map_location=Nonemap_location=torch.device()map_location={xx:xx}总结参考torch.load()函数格式:torch.load(f, map_location=None)
模型加载中的map_location参数作用:指定模型加载到哪个设备上,例如CPU或GPU。
选项:- map_location=None:默认使用与保存模型时相同的设备。- map_location=torch.device():指定设备类型,如torch.device('cpu')或torch.device('cuda:0')。- map_location={xx:xx}:自定义设备映射,如{0: 'cuda:0', 1: 'cpu'}。
总结在使用torch.load()加载模型时,通过map_location参数可以控制模型加载到特定的设备上,以确保模型能够正确运行。
目录
- 参考
- torch.load()
- 模型的保存
- 模型加载中的map_location参数
- map_location=None
- map_location=torch.device()
- map_location={xx:xx}
- 总结
参考
TORCH.LOAD
torch.load()
函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。
模型的保存
模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。
另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
具体可参考:PyTorch模型的保存与加载。
模型加载中的map_location参数
具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。
首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。
map_location=None
我们先把state_dict加载进来。
model_path = "./cuda_model.pth" model = torch.load(model_path) print(next(model.parameters()).device)
结果为:
cuda:0
因为保存的时候就是模型就是cuda:0的,所以加载进来也是。
map_location=torch.device()
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location=torch.device('cpu')) print(next(model.parameters()).device)
结果为:
cpu
模型从cuda:0变成了cpu。
map_location={xx:xx}
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:0':'cuda:1'}) print(next(model.parameters()).device)
结果为:
cuda:1
模型从cuda:0变成了cuda:1。
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:2':'cpu'}) print(next(model.parameters()).device)
结果为:
cuda:0
模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2到cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。
总结
到此这篇关于Python中torch.load()加载模型以及其map_location参数详解的文章就介绍到这了,更多相关torch.load()加载模型map_location参数内容请搜索自由互联以前的文章或继续浏览下面的相关文章希望大家以后多多支持自由互联!
本文共计931个文字,预计阅读时间需要4分钟。
目录参考torch.load()模型保存模型加载中的map_location参数map_location=Nonemap_location=torch.device()map_location={xx:xx}总结参考torch.load()函数格式:torch.load(f, map_location=None)
模型加载中的map_location参数作用:指定模型加载到哪个设备上,例如CPU或GPU。
选项:- map_location=None:默认使用与保存模型时相同的设备。- map_location=torch.device():指定设备类型,如torch.device('cpu')或torch.device('cuda:0')。- map_location={xx:xx}:自定义设备映射,如{0: 'cuda:0', 1: 'cpu'}。
总结在使用torch.load()加载模型时,通过map_location参数可以控制模型加载到特定的设备上,以确保模型能够正确运行。
目录
- 参考
- torch.load()
- 模型的保存
- 模型加载中的map_location参数
- map_location=None
- map_location=torch.device()
- map_location={xx:xx}
- 总结
参考
TORCH.LOAD
torch.load()
函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。
模型的保存
模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。
另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
具体可参考:PyTorch模型的保存与加载。
模型加载中的map_location参数
具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。
首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。
map_location=None
我们先把state_dict加载进来。
model_path = "./cuda_model.pth" model = torch.load(model_path) print(next(model.parameters()).device)
结果为:
cuda:0
因为保存的时候就是模型就是cuda:0的,所以加载进来也是。
map_location=torch.device()
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location=torch.device('cpu')) print(next(model.parameters()).device)
结果为:
cpu
模型从cuda:0变成了cpu。
map_location={xx:xx}
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:0':'cuda:1'}) print(next(model.parameters()).device)
结果为:
cuda:1
模型从cuda:0变成了cuda:1。
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:2':'cpu'}) print(next(model.parameters()).device)
结果为:
cuda:0
模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2到cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。
总结
到此这篇关于Python中torch.load()加载模型以及其map_location参数详解的文章就介绍到这了,更多相关torch.load()加载模型map_location参数内容请搜索自由互联以前的文章或继续浏览下面的相关文章希望大家以后多多支持自由互联!

