Why is PyTorch failing to export the 'Scatter' function call in Python?
- 内容介绍
- 文章标签
- 相关推荐
本文共计354个文字,预计阅读时间需要2分钟。
在使用PyTorch进行模型导出时,如果遇到错误 RuntimeError: Could not export Python function call 'Scatter'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation?,通常是因为模型中包含了不能直接导出的Python函数调用。
以下是修改后的内容,确保不超过100字:
PyTorch模型导出时出错:无法导出包含Python函数 'Scatter' 的调用。请确保移除Python函数调用,并添加 @script 或 @script_method 注解。
pytorch
用pytorch的 trace 导出模型的时候,报错
error
RuntimeError:Could not export Python function call 'Scatter'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(13): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(15): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(28): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(36): scatter_kwargs
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(168): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(157): forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(709): _slow_forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(725): _call_impl
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(940): trace_module
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(742): trace
<ipython-input-14-e92379b43790>(2): <module>
解决方案
将model改为
model = model.module
本文共计354个文字,预计阅读时间需要2分钟。
在使用PyTorch进行模型导出时,如果遇到错误 RuntimeError: Could not export Python function call 'Scatter'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation?,通常是因为模型中包含了不能直接导出的Python函数调用。
以下是修改后的内容,确保不超过100字:
PyTorch模型导出时出错:无法导出包含Python函数 'Scatter' 的调用。请确保移除Python函数调用,并添加 @script 或 @script_method 注解。
pytorch
用pytorch的 trace 导出模型的时候,报错
error
RuntimeError:Could not export Python function call 'Scatter'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(13): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(15): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(28): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(36): scatter_kwargs
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(168): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(157): forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(709): _slow_forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(725): _call_impl
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(940): trace_module
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(742): trace
<ipython-input-14-e92379b43790>(2): <module>
解决方案
将model改为
model = model.module

