issue deploying libra-retinanet to onnx due to BFP neck component #2763

reiffd7 opened this issue May 9, 2024 · 0 comments

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read the FAQ documentation but cannot get the expected help.
  • The bug has not been fixed in the latest version.

Describe the bug

I can't deploy libra-retinanet to ONNX. I tested the model with and without the BFP neck component, and deployment succeeds only without it. I'm wondering whether a PyTorch or ONNX versioning problem is preventing this component from being exported. Has anyone encountered this?
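For context (this is my reading, not code from the report): mmdet's BFP gathers every pyramid level to the resolution of the refine level with an adaptive pool whose `output_size` is read from a tensor's shape at runtime. A minimal sketch of that pattern, which I believe is what the exporter trips over:

```python
import torch
import torch.nn.functional as F


class GatherLikeBFP(torch.nn.Module):
    """Hypothetical sketch of the gather step in mmdet's BFP neck
    (not the report's actual code): pool every level to the spatial
    size of the refine level."""

    def forward(self, feats):
        # During ONNX tracing, size()[2:] is a traced (dynamic) value,
        # so the adaptive pool's output_size is not a constant -- the
        # condition the RuntimeError in the traceback complains about.
        gather_size = feats[1].size()[2:]
        return [F.adaptive_max_pool2d(f, output_size=gather_size) for f in feats]


m = GatherLikeBFP()
outs = m([torch.randn(1, 8, 64, 64), torch.randn(1, 8, 32, 32)])
# Runs fine in eager mode; passing this module through torch.onnx.export
# is what appears to reproduce the "output_size is not constant" failure.
```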

Reproduction

I ran the following mmdeploy command:

python3 ${MMDEPLOY_PATH}/tools/deploy.py \
    ${MMDEPLOY_PATH}/configs/mmdet/detection/detection_tensorrt-int8_static-544x960.py \
    ${ROOT_PATH}/model_data/person_detection/${MODEL_CHECKPOINT_NAME}/configs/libra-retinanet_r50_fpn_1x_coco.py \
    ${ROOT_PATH}/model_data/person_detection/${MODEL_CHECKPOINT_NAME}/models/pytorch/bfp_cross_entropy_epoch_44.pth \
    ${INPUT_IMG} \
    --work-dir ${WORK_DIR} \
    --device cuda \
    --log-level ERROR \
    --dump-info

The model config looks like this:

model = dict(
    type='RetinaNet',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        pad_size_divisor=32),
    backbone=dict(
        type='RegNet',
        arch='regnetx_400mf',
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint='open-mmlab://regnetx_400mf')
    ),
    neck=[
        dict(
            type='FPN',
            in_channels=[32, 64, 160, 384],
            out_channels=256,
            start_level=1,
            add_extra_convs='on_input',
            num_outs=5),
        dict(
            type='BFP',
            in_channels=256,
            num_levels=5,
            refine_level=1,
            refine_type='non_local')
    ],
    bbox_head=dict(
        type='ATSSHead',
        num_classes=1,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        norm_cfg=None,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(type='ATSSAssigner', topk=9),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=16,
        score_thr=0.25,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=300))

Deployment runs fine without the BFP neck.
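Not part of the original report, but since adjacent FPN levels differ by a fixed factor of 2, one possible workaround is a BFP-style gather/scatter that uses static `max_pool2d` and `interpolate` (constant kernel sizes and scale factors) instead of adaptive pooling. A hypothetical sketch, with the refinement conv/non-local block omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StaticGatherBFP(nn.Module):
    """Hypothetical ONNX-friendly sketch of BFP's gather/scatter
    (assumes adjacent pyramid levels differ by a stride of exactly 2;
    the refine step of the real BFP is omitted)."""

    def __init__(self, num_levels=5, refine_level=1):
        super().__init__()
        self.num_levels = num_levels
        self.refine_level = refine_level

    def forward(self, inputs):
        # Gather: bring every level to the refine level's resolution
        # using static ops, so traced sizes stay constant.
        feats = []
        for i in range(self.num_levels):
            if i < self.refine_level:
                scale = 2 ** (self.refine_level - i)
                feats.append(F.max_pool2d(inputs[i], kernel_size=scale, stride=scale))
            elif i > self.refine_level:
                scale = 2 ** (i - self.refine_level)
                feats.append(F.interpolate(inputs[i], scale_factor=scale, mode='nearest'))
            else:
                feats.append(inputs[i])
        bsf = sum(feats) / self.num_levels

        # Scatter: add the balanced feature back to each level at its
        # original resolution, again with constant scale factors.
        outs = []
        for i in range(self.num_levels):
            if i < self.refine_level:
                scale = 2 ** (self.refine_level - i)
                outs.append(inputs[i] + F.interpolate(bsf, scale_factor=scale, mode='nearest'))
            elif i > self.refine_level:
                scale = 2 ** (i - self.refine_level)
                outs.append(inputs[i] + F.max_pool2d(bsf, kernel_size=scale, stride=scale))
            else:
                outs.append(inputs[i] + bsf)
        return outs
```

Swapping something like this in for the stock `BFP.forward` (e.g. via an mmdeploy function rewriter) might let the export trace with only constant sizes, at the cost of diverging slightly from the trained module's exact pooling behavior.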

Environment

FROM openmmlab/mmdeploy:ubuntu20.04-cuda11.3-mmdeploy1.1.0
#FROM openmmlab/mmdeploy:ubuntu20.04-cuda11.8-mmdeploy

WORKDIR /root/workspace

RUN pip install joblib
RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113

RUN pip install -U pip
RUN pip install joblib
RUN pip install openmim

RUN pip install fire onnx onnxsim sclblonnx
RUN mim install mmcv==2.1
RUN mim install mmrazor
RUN mim install "mmdet>3"
RUN mim install mmdeploy==1.3.0

RUN apt-get clean && rm -rf /var/lib/apt/lists/*


Error traceback

```Shell
Convert PyTorch model to ONNX graph...
05/09 18:44:21 - mmengine - ERROR - /root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - __call__ - 94 - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
05/09 18:44:22 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
05/09 18:44:22 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
/usr/local/lib/python3.8/dist-packages/mmdet/models/dense_heads/anchor_head.py:108: UserWarning: DeprecationWarning: `num_anchors` is deprecated, for consistency or also use `num_base_priors` instead
  warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
Loads checkpoint by local backend from path: /opt/ml/forsight-model-optimization/model_data/person_detection/retinanet_regnetx_400mf_fpn_1x8_1x_person_detection_19_04_2024_subset/models/pytorch/bfp_cross_entropy_epoch_44.pth
05/09 18:44:25 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future. 
05/09 18:44:25 - mmengine - INFO - Export PyTorch model to ONNX: /opt/ml/forsight-model-optimization/model_data/person_detection/retinanet_regnetx_400mf_fpn_1x8_1x_person_detection_19_04_2024_subset/models/onnx/end2end.onnx.
05/09 18:44:25 - mmengine - WARNING - Can not find torch._C._jit_pass_onnx_autograd_function_process, function rewrite will not be applied
05/09 18:44:25 - mmengine - WARNING - Can not find mmdet.models.utils.transformer.PatchMerging.forward, function rewrite will not be applied
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py:109: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
/root/workspace/mmdeploy/mmdeploy/pytorch/functions/topk.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if k > size:
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py:38: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert pred_bboxes.size(0) == bboxes.size(0)
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py:40: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert pred_bboxes.size(1) == bboxes.size(1)
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:451: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  int(scores.shape[-1]),
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:148: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  out_boxes = min(num_boxes, after_topk)
Process Process-2:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1187, in symbolic_fn
    output_size = symbolic_helper._parse_arg(output_size, "is")
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 97, in _parse_arg
    raise RuntimeError(
RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 98, in torch2onnx
    export(
  File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 356, in _wrap
    return self.call_function(func_name_, *args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 326, in call_function
    return self.call_function_local(func_name, *args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 275, in call_function_local
    return pipe_caller(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/onnx/export.py", line 131, in export
    torch.onnx.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/root/workspace/mmdeploy/mmdeploy/apis/onnx/optimizer.py", line 27, in model_to_graph__custom_optimizer
    graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 731, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 308, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 416, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1406, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 308, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1189, in symbolic_fn
    return symbolic_helper._onnx_unsupported(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 454, in _onnx_unsupported
    raise RuntimeError(
RuntimeError: Unsupported: ONNX export of operator adaptive pooling, since output_size is not constant.. Please feel free to request support or submit a pull request on PyTorch GitHub.
```