[Bug] RTMDet model to ONNXRuntime or Torchscript doesn't produce masks #2735

Open
3 tasks done
soulslicer opened this issue Apr 11, 2024 · 0 comments

soulslicer commented Apr 11, 2024

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read the FAQ documentation but cannot get the expected help.
  • The bug has not been fixed in the latest version.

Describe the bug

When converting the pretrained RTMDet instance-segmentation model (rtmdet-ins_l_8xb32-300e_coco) to ONNX Runtime or TorchScript with MMDeploy, the backend model returns bboxes and scores but no masks. Code to reproduce is below.

Reproduction

Code to reproduce and run

```python
import os

import mmcv
import torch
from mmdet.apis import DetInferencer, inference_detector, init_detector
from mmdet.utils import register_all_modules

# Checkpoints downloaded by DetInferencer are stored in the torch hub cache.
cache = os.path.expanduser('~/.cache/torch/hub/')

# Run the stock inferencer once (this also downloads the checkpoint).
inferencer = DetInferencer('rtmdet-ins_l_8xb32-300e_coco')
inferencer('demo/demo.jpg')

config_file = 'configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
checkpoint_file = os.path.join(
    cache, 'checkpoints/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth')
register_all_modules()

# Reference result: plain PyTorch inference.
model = init_detector(config_file, checkpoint_file, device='cuda:0')  # or device='cpu'
# mmcv.imread returns BGR by default, which is what the mmdet pipeline expects.
image = mmcv.imread('demo/demo.jpg')
result = inference_detector(model, image)

print(result.pred_instances.all_keys())
print("*****************")

from mmdeploy.apis import torch2onnx
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import get_input_shape, load_config

img = 'demo/demo.jpg'
work_dir = 'mmdeploy_models/mmdet/onnx'
backend_model = ['mmdeploy_models/mmdet/onnx/end2end.onnx']
save_file = 'end2end.onnx'
deploy_cfg = '../mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py'
model_cfg = 'configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
model_checkpoint = checkpoint_file
device = 'cuda'

# Export the PyTorch model to ONNX and dump the SDK meta info.
torch2onnx(img, work_dir, save_file, deploy_cfg, model_cfg,
           model_checkpoint, device)

export2SDK(deploy_cfg, model_cfg, work_dir, pth=model_checkpoint,
           device=device)

# read deploy_cfg and model_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

# build task and backend model
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
model = task_processor.build_backend_model(backend_model)

# process input image
input_shape = get_input_shape(deploy_cfg)
model_inputs, _ = task_processor.create_input(img, input_shape)

# do model inference
with torch.no_grad():
    result = model.test_step(model_inputs)
    print(result[0].pred_instances.all_keys())
```

Notice that I print the result keys of both the PyTorch model and the ONNX Runtime backend model. The first has ['masks', 'scores', 'priors', 'labels', 'bboxes', 'kernels'], but the second has only ['labels', 'bboxes', 'scores'].
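To make the comparison concrete, here is a minimal sketch that diffs the two key lists. The lists are copied by hand from the printed output; nothing here calls MMDetection or MMDeploy, and `missing_keys` is a hypothetical helper.

```python
# Keys printed by the two `print(...all_keys())` calls in the reproduction
# script, copied from the output reported above.
pytorch_keys = ['masks', 'scores', 'priors', 'labels', 'bboxes', 'kernels']
onnxruntime_keys = ['labels', 'bboxes', 'scores']


def missing_keys(reference, candidate):
    """Return the keys in `reference` that are absent from `candidate`, sorted."""
    return sorted(set(reference) - set(candidate))


if __name__ == '__main__':
    print(missing_keys(pytorch_keys, onnxruntime_keys))  # ['kernels', 'masks', 'priors']
```

Of the three missing fields, `masks` is the one a downstream consumer needs; `priors` and `kernels` look like intermediate fields of the RTMDet-Ins head, though that reading is an assumption.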

The same happens with TorchScript:


```python
import os

import torch
from mmdeploy.apis import torch2torchscript
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import get_input_shape, load_config

# Same checkpoint cache directory as in the first script.
cache = os.path.expanduser('~/.cache/torch/hub/')

img = 'demo/demo.jpg'
work_dir = 'mmdeploy_models/mmdet/torchscript'
backend_model = ['mmdeploy_models/mmdet/torchscript/end2end.pt']
save_file = 'end2end.pt'
deploy_cfg = '../mmdeploy/configs/mmdet/detection/detection_torchscript.py'
model_cfg = 'configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
model_checkpoint = os.path.join(
    cache, 'checkpoints/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth')
device = 'cuda'

# Export the PyTorch model to TorchScript and dump the SDK meta info.
torch2torchscript(img, work_dir, save_file, deploy_cfg, model_cfg,
                  model_checkpoint, device)

export2SDK(deploy_cfg, model_cfg, work_dir, pth=model_checkpoint,
           device=device)

# read deploy_cfg and model_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

# build task and backend model
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
model = task_processor.build_backend_model(backend_model)

# process input image
input_shape = get_input_shape(deploy_cfg)
model_inputs, _ = task_processor.create_input(img, input_shape)

# do model inference
with torch.no_grad():
    result = model.test_step(model_inputs)
    print(result[0].pred_instances.all_keys())
```
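One thing worth checking when triaging this: a backend model can only return the outputs its exported graph declares, and both scripts above load deploy configs from `configs/mmdet/detection/`. The sketch below inspects a loaded deploy config for a `masks` output. It assumes (verify against the MMDeploy repository) that exported outputs are listed in `output_names` under `onnx_config` for ONNX or `ir_config` for TorchScript, and that instance-segmentation deploy configs include `'masks'` there while detection-only configs do not. `exports_masks` is a hypothetical helper, not part of MMDeploy.

```python
def exports_masks(deploy_cfg) -> bool:
    """Hypothetical helper: does the deploy config declare a 'masks' output?

    `deploy_cfg` is any dict-like object with `.get`, e.g. a plain dict or the
    mmengine Config returned by mmdeploy.utils.load_config. The section names
    'ir_config' / 'onnx_config' and the 'masks' output name are assumptions.
    """
    for section_name in ('ir_config', 'onnx_config'):
        section = deploy_cfg.get(section_name) or {}
        output_names = section.get('output_names') or []
        if 'masks' in output_names:
            return True
    return False


if __name__ == '__main__':
    # Plain dicts standing in for loaded deploy configs (illustrative values).
    detection_like = {'onnx_config': {'output_names': ['dets', 'labels']}}
    instance_seg_like = {'onnx_config': {'output_names': ['dets', 'labels', 'masks']}}
    print(exports_masks(detection_like))     # False
    print(exports_masks(instance_seg_like))  # True
```

In the scripts above, this would be called right after `load_config(...)`. If it returns `False`, the missing masks may come from the choice of deploy config rather than from the export step itself.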

Environment

04/10 19:17:14 - mmengine - INFO - 

04/10 19:17:14 - mmengine - INFO - **********Environmental information**********
04/10 19:17:14 - mmengine - INFO - sys.platform: linux
04/10 19:17:14 - mmengine - INFO - Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
04/10 19:17:14 - mmengine - INFO - CUDA available: True
04/10 19:17:14 - mmengine - INFO - MUSA available: False
04/10 19:17:14 - mmengine - INFO - numpy_random_seed: 2147483648
04/10 19:17:14 - mmengine - INFO - GPU 0: NVIDIA GeForce RTX 3050 Ti Laptop GPU
04/10 19:17:14 - mmengine - INFO - CUDA_HOME: /usr/local/cuda-11.7
04/10 19:17:14 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.7, V11.7.99
04/10 19:17:14 - mmengine - INFO - GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
04/10 19:17:14 - mmengine - INFO - PyTorch: 1.13.0+cu117
04/10 19:17:14 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.5
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

04/10 19:17:14 - mmengine - INFO - TorchVision: 0.14.0+cu117
04/10 19:17:14 - mmengine - INFO - OpenCV: 4.9.0
04/10 19:17:14 - mmengine - INFO - MMEngine: 0.10.3
04/10 19:17:14 - mmengine - INFO - MMCV: 2.1.0
04/10 19:17:14 - mmengine - INFO - MMCV Compiler: GCC 9.3
04/10 19:17:14 - mmengine - INFO - MMCV CUDA Compiler: 11.7
04/10 19:17:14 - mmengine - INFO - MMDeploy: 1.3.1+cfd5d3a
04/10 19:17:14 - mmengine - INFO - 

04/10 19:17:14 - mmengine - INFO - **********Backend information**********
04/10 19:17:14 - mmengine - INFO - tensorrt:	8.6.1
04/10 19:17:14 - mmengine - INFO - tensorrt custom ops:	Available
04/10 19:17:14 - mmengine - INFO - ONNXRuntime:	None
04/10 19:17:14 - mmengine - INFO - ONNXRuntime-gpu:	1.17.1
04/10 19:17:14 - mmengine - INFO - ONNXRuntime custom ops:	Available
04/10 19:17:14 - mmengine - INFO - pplnn:	None
04/10 19:17:14 - mmengine - INFO - ncnn:	1.0.20240410
04/10 19:17:14 - mmengine - INFO - ncnn custom ops:	Available
04/10 19:17:14 - mmengine - INFO - snpe:	None
04/10 19:17:14 - mmengine - INFO - openvino:	2024.0.0
04/10 19:17:14 - mmengine - INFO - torchscript:	1.13.0
04/10 19:17:14 - mmengine - INFO - torchscript custom ops:	Available
04/10 19:17:14 - mmengine - INFO - rknn-toolkit:	None
04/10 19:17:14 - mmengine - INFO - rknn-toolkit2:	None
04/10 19:17:14 - mmengine - INFO - ascend:	None
04/10 19:17:14 - mmengine - INFO - coreml:	None
04/10 19:17:14 - mmengine - INFO - tvm:	None
04/10 19:17:14 - mmengine - INFO - vacc:	None
04/10 19:17:14 - mmengine - INFO - 

04/10 19:17:14 - mmengine - INFO - **********Codebase information**********
04/10 19:17:14 - mmengine - INFO - mmdet:	3.3.0
04/10 19:17:14 - mmengine - INFO - mmseg:	None
04/10 19:17:14 - mmengine - INFO - mmpretrain:	None
04/10 19:17:14 - mmengine - INFO - mmocr:	None
04/10 19:17:14 - mmengine - INFO - mmagic:	None
04/10 19:17:14 - mmengine - INFO - mmdet3d:	None
04/10 19:17:14 - mmengine - INFO - mmpose:	None
04/10 19:17:14 - mmengine - INFO - mmrotate:	None
04/10 19:17:14 - mmengine - INFO - mmaction:	None
04/10 19:17:14 - mmengine - INFO - mmrazor:	None
04/10 19:17:14 - mmengine - INFO - mmyolo:	None

Error traceback

04/10 19:16:06 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
04/10 19:16:06 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
Loads checkpoint by local backend from path: /home/yaadhavraaj/.cache/torch/hub/checkpoints/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth
04/10 19:16:07 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future. 
04/10 19:16:07 - mmengine - INFO - Export PyTorch model to ONNX: mmdeploy_models/mmdet/onnx/end2end.onnx.
04/10 19:16:07 - mmengine - WARNING - Can not find torch.nn.functional.scaled_dot_product_attention, function rewrite will not be applied
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:285: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:286: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/pytorch/functions/topk.py:28: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  k = torch.tensor(k, device=input.device, dtype=torch.long)
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:45: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  score_threshold = float(score_threshold)
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:46: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  iou_threshold = float(iou_threshold)
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/mmcv/ops/nms.py:123: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(1) == 4
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/mmcv/ops/nms.py:124: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(0) == scores.size(0)
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:5408: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:87: FutureWarning: 'torch.onnx._patch_torch._graph_op' is deprecated in version 1.13 and will be removed in version 1.14. Please note 'g.op()' is to be removed from torch.Graph. Please open a GitHub issue if you need this functionality..
  max_output_boxes_per_class = g.op(
/home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/mmcv/ops/nms.py:101: FutureWarning: 'torch.onnx._patch_torch._graph_op' is deprecated in version 1.13 and will be removed in version 1.14. Please note 'g.op()' is to be removed from torch.Graph. Please open a GitHub issue if you need this functionality..
  return g.op('NonMaxSuppression', boxes, scores,
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_graph_shape_type_inference(
/home/yaadhavraaj/mmlab/base/lib/python3.10/site-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_graph_shape_type_inference(
04/10 19:16:11 - mmengine - INFO - Execute onnx optimize passes.
04/10 19:16:12 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "backend_detectors" registry tree. As a workaround, the current "backend_detectors" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
04/10 19:16:12 - mmengine - INFO - Successfully loaded onnxruntime custom ops from /home/yaadhavraaj/mmlab/mmdeploy/mmdeploy/lib/libmmdeploy_onnxruntime_ops.so
2024-04-10 19:16:12.925876675 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 11 Memcpy nodes are added to the graph torch_jit for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.
2024-04-10 19:16:12.929131429 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-10 19:16:12.929140027 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
soulslicer changed the title from "[Bug] RTMDet model to ONNXRuntime doesn't produce masks" to "[Bug] RTMDet model to ONNXRuntime or Torchscript doesn't produce masks" on Apr 11, 2024.