### Checklist
1. I have searched related issues but cannot get the expected help.
2. I have read the FAQ documentation but cannot get the expected help.
3. The bug has not been fixed in the latest version.
### Describe the bug
I can't deploy libra-retinanet to ONNX. I tested the model with and without the BFP neck component, and conversion succeeds only when the BFP neck is removed. I suspect a PyTorch or ONNX versioning problem prevents this component from being exported. Has anyone encountered this?
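For context on why the BFP neck specifically would trip this error: in mmdet's BFP (`mmdet/models/necks/bfp.py`), the gather step calls `F.adaptive_max_pool2d(feat, output_size=gather_size)` where `gather_size` is read from `feats[refine_level].size()[2:]`. Under `torch.onnx.export` tracing, those sizes are captured as Tensors rather than Python ints, which matches the "output_size is not constant" failure below. The helper here is a hypothetical sketch (not part of mmdeploy or mmdet) of the usual workaround: cast the target size to plain ints before the pooling call so the exporter sees a constant.

```python
# Hedged sketch of a workaround for "ONNX export of operator adaptive
# pooling, since output_size is not constant". make_static_size is a
# hypothetical helper, not an mmdeploy/mmdet API.

def make_static_size(size):
    """Convert a possibly tensor-valued (H, W) pair into plain Python ints.

    Handles ints/floats as well as any object exposing .item(), e.g. a
    0-dim torch.Tensor captured while tracing.
    """
    def as_int(v):
        return int(v.item()) if hasattr(v, "item") else int(v)
    return tuple(as_int(v) for v in size)
```

Applied inside BFP, this would mean wrapping the gather size, e.g. `gather_size = make_static_size(feats[self.refine_level].size()[2:])`. Note the trade-off: the pooled size gets baked in at tracing time, so the exported graph is only valid for the input resolution used during export.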
```Dockerfile
FROM openmmlab/mmdeploy:ubuntu20.04-cuda11.3-mmdeploy1.1.0
# FROM openmmlab/mmdeploy:ubuntu20.04-cuda11.8-mmdeploy
WORKDIR /root/workspace
RUN pip install -U pip
RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install joblib openmim fire onnx onnxsim sclblonnx
RUN mim install mmcv==2.1
RUN mim install mmrazor
# Quote the constraint so the shell does not treat ">3" as an output redirection
RUN mim install "mmdet>3"
RUN mim install mmdeploy==1.3.0
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
```
### Error traceback
```Shell
Convert PyTorch model to ONNX graph...
05/09 18:44:21 - mmengine - ERROR - /root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - __call__ - 94 - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
05/09 18:44:22 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
05/09 18:44:22 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
/usr/local/lib/python3.8/dist-packages/mmdet/models/dense_heads/anchor_head.py:108: UserWarning: DeprecationWarning: `num_anchors` is deprecated, for consistency or also use `num_base_priors` instead
warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
Loads checkpoint by local backend from path: /opt/ml/forsight-model-optimization/model_data/person_detection/retinanet_regnetx_400mf_fpn_1x8_1x_person_detection_19_04_2024_subset/models/pytorch/bfp_cross_entropy_epoch_44.pth
05/09 18:44:25 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future.
05/09 18:44:25 - mmengine - INFO - Export PyTorch model to ONNX: /opt/ml/forsight-model-optimization/model_data/person_detection/retinanet_regnetx_400mf_fpn_1x8_1x_person_detection_19_04_2024_subset/models/onnx/end2end.onnx.
05/09 18:44:25 - mmengine - WARNING - Can not find torch._C._jit_pass_onnx_autograd_function_process, function rewrite will not be applied
05/09 18:44:25 - mmengine - WARNING - Can not find mmdet.models.utils.transformer.PatchMerging.forward, function rewrite will not be applied
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
ys_shape = tuple(int(s) for s in ys.shape)
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py:109: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
/root/workspace/mmdeploy/mmdeploy/pytorch/functions/topk.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if k > size:
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py:38: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert pred_bboxes.size(0) == bboxes.size(0)
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py:40: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert pred_bboxes.size(1) == bboxes.size(1)
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:451: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
int(scores.shape[-1]),
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:148: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
out_boxes = min(num_boxes, after_topk)
Process Process-2:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1187, in symbolic_fn
output_size = symbolic_helper._parse_arg(output_size, "is")
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 97, in _parse_arg
raise RuntimeError(
RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
ret = func(*args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 98, in torch2onnx
export(
File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 356, in _wrap
return self.call_function(func_name_, *args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 326, in call_function
return self.call_function_local(func_name, *args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 275, in call_function_local
return pipe_caller(*args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
ret = func(*args, **kwargs)
File "/root/workspace/mmdeploy/mmdeploy/apis/onnx/export.py", line 131, in export
torch.onnx.export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
return utils.export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
_export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1074, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/root/workspace/mmdeploy/mmdeploy/apis/onnx/optimizer.py", line 27, in model_to_graph__custom_optimizer
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 731, in _model_to_graph
graph = _optimize_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 308, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 416, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1406, in _run_symbolic_function
return symbolic_fn(g, *inputs, **attrs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 308, in wrapper
return fn(g, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1189, in symbolic_fn
return symbolic_helper._onnx_unsupported(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 454, in _onnx_unsupported
raise RuntimeError(
RuntimeError: Unsupported: ONNX export of operator adaptive pooling, since output_size is not constant.. Please feel free to request support or submit a pull request on PyTorch GitHub.
```
### Reproduction
I ran the following mmdeploy command:
The model config looks like this:
Deployment runs fine without the BFP neck.
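The actual model config is not included above. For orientation only, a BFP neck in mmdet is normally stacked after an FPN as in the Libra R-CNN configs; the values below are illustrative placeholders, not the reporter's RegNetX-400MF config:

```python
# Illustrative mmdet-style neck config with a BFP stage (Libra pattern).
# Channel counts and level indices are placeholders, not the actual config.
neck = [
    dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],  # backbone stage channels (illustrative)
        out_channels=256,
        num_outs=5),
    dict(
        type='BFP',                 # Balanced Feature Pyramid
        in_channels=256,
        num_levels=5,
        refine_level=2,             # level whose spatial size the other levels are resized to
        refine_type='non_local'),
]
```

The spatial size of the `refine_level` feature map is what BFP passes to adaptive pooling, which is the operation the ONNX export rejects.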
### Environment