How did you generate other model weights? #156

Open
agrija9 opened this issue Mar 14, 2022 · 0 comments

agrija9 commented Mar 14, 2022

Hello @jaybdub, I have been able to run the live_demo.ipynb on a Jetson Xavier NX with the two provided models (resnet18 and densenet121).

However, I need better accuracy for my application. When I try to run the notebook with a different backbone, e.g. resnet50, loading the weights fails with a `RuntimeError` about missing and unexpected keys (the traceback points at `_IncompatibleKeys(missing_keys, unexpected_keys)` inside `load_state_dict`).

This is how I create the model and try to load the weights:

```python
import json
import trt_pose.coco
import torch
import trt_pose.models
import torch2trt
from torch2trt import TRTModule
import time
import cv2
import torchvision.transforms as transforms
import PIL.Image

with open('human_pose.json', 'r') as f:
    human_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(human_pose)

num_parts = len(human_pose['keypoints'])
num_links = len(human_pose['skeleton'])

# Constructing the model downloads pretrained weights into /home/jetson3/.cache/torch/hub/checkpoints
model = trt_pose.models.resnet50_baseline_att(num_parts, 2 * num_links).cuda().eval()

# Load model weights
MODEL_WEIGHTS = "/home/jetson3/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth"
model.load_state_dict(torch.load(MODEL_WEIGHTS))
```

The error is raised at the `load_state_dict` call:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-1687e7d09b26> in <module>
      7 MODEL_WEIGHTS = "/home/jetson3/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth"
----> 8 model.load_state_dict(torch.load(MODEL_WEIGHTS))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for Sequential:
```
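
For context on why this error appears: `load_state_dict` matches tensors by key name, and in strict mode any model key absent from the checkpoint (missing) or any checkpoint key the model does not claim (unexpected) is an error. `resnet50-0676ba61.pth` matches the filename of torchvision's ImageNet-pretrained ResNet-50 checkpoint, so it appears to contain only classification-backbone weights rather than a trained pose model, while the error names a `Sequential` model whose keys carry child-index prefixes. The sketch below is a simplified, name-only model of that check in plain Python (real PyTorch also compares tensor shapes). All key names are illustrative assumptions: the bare names follow torchvision's ResNet convention, but the `0.resnet.` prefix and the `1.cmap_conv.weight` head key are hypothetical and may not match trt_pose's actual names.

```python
from typing import Iterable, List, NamedTuple


class KeyReport(NamedTuple):
    # Same two fields as PyTorch's _IncompatibleKeys named tuple.
    missing_keys: List[str]
    unexpected_keys: List[str]


def compare_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> KeyReport:
    """Name-only sketch: keys the model expects but the checkpoint lacks, and vice versa."""
    model_list = list(model_keys)
    ckpt_list = list(checkpoint_keys)
    model_set, ckpt_set = set(model_list), set(ckpt_list)
    missing = [k for k in model_list if k not in ckpt_set]
    unexpected = [k for k in ckpt_list if k not in model_set]
    return KeyReport(missing, unexpected)


# Hypothetical, heavily truncated key lists.
# A torchvision-style classification checkpoint stores bare parameter names ...
checkpoint = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight", "fc.weight", "fc.bias"]
# ... while a Sequential(backbone, head) model prefixes every key with its child index.
model = ["0.resnet.conv1.weight", "0.resnet.bn1.weight",
         "0.resnet.layer1.0.conv1.weight", "1.cmap_conv.weight"]

report = compare_keys(model, checkpoint)
print(report.missing_keys)     # every model key: no names line up
print(report.unexpected_keys)  # every checkpoint key

# Stripping the (hypothetical) backbone prefix makes the backbone names match,
# but the pose head keys have no counterpart in the checkpoint at all.
stripped = [k[len("0.resnet."):] for k in model if k.startswith("0.resnet.")]
backbone_report = compare_keys(stripped, checkpoint)
print(backbone_report.missing_keys)
print(backbone_report.unexpected_keys)
```

In real PyTorch, calling `model.load_state_dict(state_dict, strict=False)` skips the error and returns the `_IncompatibleKeys` tuple, so printing its `missing_keys` and `unexpected_keys` is a quick way to see which side of the mismatch a checkpoint falls on.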

Appreciate your help!