Bug in onnx2trt_utils.cpp at line 2165 with the assertion: indicesDims.d[i] <= dataDims.d[i] #963

Open
ControllableGeneration opened this issue Apr 7, 2024 · 0 comments

Description

for (int32_t i = 0; i < dataDims.nbDims; ++i)
{
    if (indicesDims.d[i] != -1 && dataDims.d[i] != -1)
    {
        ASSERT(indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!",
            ErrorCode::kUNSUPPORTED_NODE);
    }
    if (updatesDims.d[i] != -1 && dataDims.d[i] != -1)
    {
        ASSERT(updatesDims.d[i] <= dataDims.d[i] && "Updates dimensions must be less than data dimensions!",
            ErrorCode::kUNSUPPORTED_NODE);
    }
}

In this section, the assertion indicesDims.d[i] <= dataDims.d[i] should be changed to indices.max() <= dataDims.d[i], and the assertion updatesDims.d[i] <= dataDims.d[i] should be changed to updates.max() <= dataDims.d[i].

This is because scatter cares about the index values rather than the shape of the index tensor. Hence dataDims need not be related to indicesDims/updatesDims, except that their ranks must be the same.
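To illustrate the point, here is a minimal pure-Python sketch of scatter-add along dim=1 for 2-D inputs. It is not the TensorRT or PyTorch implementation; the function name `scatter_add_dim1` is made up for this example, and the bound on index values assumes the ONNX ScatterElements convention that indices lie in [-s, s-1] for an axis of size s.

```python
from typing import List


def scatter_add_dim1(data: List[List[float]],
                     index: List[List[int]],
                     src: List[List[float]]) -> List[List[float]]:
    """Reference scatter-add along dim=1: out[i][index[i][j]] += src[i][j]."""
    if len(index) > len(data):
        # Off the scatter axis, index must not have more rows than data.
        raise ValueError("index has more rows than data")
    out = [row[:] for row in data]  # copy so the input is not modified
    for i, idx_row in enumerate(index):
        axis_size = len(out[i])
        for j, k in enumerate(idx_row):
            # What matters along the scatter axis is the index VALUE,
            # not how many indices there are.
            if not -axis_size <= k < axis_size:
                raise IndexError(f"index value {k} out of range for axis size {axis_size}")
            out[i][k] += src[i][j]
    return out


# Shapes from this report: data is (1, 20), index and src are (1, 4096).
data = [[0.0] * 20]
index = [[j % 20 for j in range(4096)]]  # every value is <= 19
src = [[1.0] * 4096]

result = scatter_add_dim1(data, index, src)
print(len(result[0]))   # 20
print(sum(result[0]))   # 4096.0
```

The index tensor is far longer than the data along dim 1, yet the operation is well defined because every index value falls inside the data's axis.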

Environment

TensorRT Version: 8.6.1.6
ONNX-TensorRT Version / Branch: 1.16
GPU Type: RTX 4070
Nvidia Driver Version: 535.146.02
CUDA Version: 12.2
CUDNN Version: 11.8
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.10
TensorFlow + TF2ONNX Version (if applicable):
PyTorch Version (if applicable): 2.2.1
Baremetal or Container (if container which image + tag):

Relevant Files

No relevant files.

Steps To Reproduce

You can reproduce it by exporting the following to ONNX and then building a TensorRT engine from the result:
y = torch.scatter_add(input=y_, dim=1, index=x, src=x_)
where y_ has shape (1, 20), x has shape (1, 4096) with maximum value <= 19, and x_ has shape (1, 4096) and is filled with 1's.

The ONNX export succeeds, but building the TensorRT engine fails.
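A possible export script for the steps above is sketched below. Several details are assumptions not stated in this report: y_ is filled with zeros, x is random, the wrapper module and the `export_repro` helper are names invented here, and opset 16 is an arbitrary choice. `torch` is imported inside the function only so that the shape constants can be inspected without it installed.

```python
DATA_SHAPE = (1, 20)      # shape of y_
INDEX_SHAPE = (1, 4096)   # shape of x and x_


def export_repro(onnx_path: str = "scatter_add.onnx", opset: int = 16) -> None:
    """Export a single scatter_add with the shapes from this report to ONNX."""
    import torch  # requires PyTorch with ONNX export support

    class ScatterAdd(torch.nn.Module):
        def forward(self, y_, x, x_):
            return torch.scatter_add(input=y_, dim=1, index=x, src=x_)

    y_ = torch.zeros(DATA_SHAPE)                      # assumed contents
    x = torch.randint(0, DATA_SHAPE[1], INDEX_SHAPE)  # int64 values in 0..19
    x_ = torch.ones(INDEX_SHAPE)                      # filled with 1's

    torch.onnx.export(
        ScatterAdd(),
        (y_, x, x_),
        onnx_path,
        input_names=["y_", "x", "x_"],
        output_names=["y"],
        opset_version=opset,
    )
```

Calling `export_repro()` writes `scatter_add.onnx`; building an engine from that file, for example with `trtexec --onnx=scatter_add.onnx`, is where the assertion is expected to trigger.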
