Lower latency than expected #27

Status: Open. Wants to merge 61 commits into base: master.

Commits (61)
2a538f1
Fix installation scripts (using python3)
jkjung-avt Sep 12, 2018
00cb45f
Add more stuff to .gitignore
jkjung-avt Sep 13, 2018
68244a0
Add logs/ into .gitignore
jkjung-avt Sep 14, 2018
8d5b726
Add camera_tf_trt.py script and the corresponding utils code
jkjung-avt Sep 14, 2018
4517702
Replace NVIDIA's README.md with my own stuff, including description …
jkjung-avt Sep 14, 2018
ac699d0
Minor updates to README.md
jkjung-avt Sep 14, 2018
21eee23
Merge pull request #1 from jkjung-avt/dev
jkjung-avt Sep 14, 2018
f1cf752
Fix typos
jkjung-avt Sep 14, 2018
1ec645e
Update some comments in camera_tf_trt.py
jkjung-avt Sep 14, 2018
7db4004
Add utils/__init__.py so as to fix problems importing stuff from tha…
jkjung-avt Sep 16, 2018
1070d9a
Create the BBoxVisualization class and implement a nicer bounding box…
jkjung-avt Sep 16, 2018
c7dace3
Fix bugs in the installation scripts
jkjung-avt Sep 17, 2018
b56ecbf
Make sure BBoxVisualization is working, and create test code for it
jkjung-avt Sep 17, 2018
bff3967
Add reference link about setting max_batch_size to avoid 'cudnnFusedC…
jkjung-avt Sep 17, 2018
cff7f4f
Use the new BBoxVisualization class to draw nicer-looking bounding bo…
jkjung-avt Sep 17, 2018
9813e81
Merge pull request #2 from jkjung-avt/dev
jkjung-avt Sep 17, 2018
d9d9cc5
Fix the issue of CPU getting occupied by grab_img thread in use_image…
jkjung-avt Sep 17, 2018
c585a44
Set display window size based on actual input image size (not the wid…
jkjung-avt Sep 17, 2018
15f0858
Update the screenshot based on test result with JetPack-3.2 and tenso…
jkjung-avt Sep 17, 2018
ff30619
Add the tensorflow SSD model config files (don't download the origina…
jkjung-avt Sep 17, 2018
a119d7b
Merge pull request #3 from jkjung-avt/dev
jkjung-avt Sep 17, 2018
e9281ab
Add link to tensorflow 1.8.0 wheel for JetPack-3.3 (built by myself)
jkjung-avt Sep 19, 2018
263ea8a
Add highlights on the pip wheel download links
jkjung-avt Sep 19, 2018
8526e04
Minor updates on text formatting/alignments
jkjung-avt Sep 19, 2018
8fd6429
Merge pull request #4 from jkjung-avt/dev
jkjung-avt Sep 19, 2018
fbb14a3
Fix the bug of class 0 (output of TensorFlow Object Detection models …
jkjung-avt Sep 25, 2018
44f3f47
Add support for ssd_mobilenet_v1_egohands
jkjung-avt Sep 25, 2018
e99a9ac
Merge pull request #5 from jkjung-avt/dev
jkjung-avt Sep 25, 2018
e265da7
Add data/egohands_label_map.pbtxt
jkjung-avt Sep 25, 2018
c6e8e66
Merge pull request #6 from jkjung-avt/dev
jkjung-avt Sep 25, 2018
3bfcedd
Fix class 0 issue
jkjung-avt Sep 26, 2018
b5b54d4
Merge pull request #7 from jkjung-avt/dev
jkjung-avt Sep 26, 2018
383c09a
Add support for 'faster_rcnn_resnet50_egohands' model
jkjung-avt Sep 28, 2018
efee38e
Have a hacky working version of TF-TRT optimized 'faster_rcnn_incepti…
jkjung-avt Oct 1, 2018
427f43a
Add support for 'ssd_mobilenet_v2_egohands', 'ssdlite_mobilenet_v2_eg…
jkjung-avt Oct 2, 2018
683c803
Merge pull request #8 from jkjung-avt/dev
jkjung-avt Oct 2, 2018
27e25d7
Add description about applying TF-TRT on the hand detector models
jkjung-avt Oct 2, 2018
846c19f
Put faster_rcnn SecondStage computations (though not optimized by TF-…
jkjung-avt Oct 3, 2018
22ede28
Use tensor names that are coded in the original TensorFlow Object Det…
jkjung-avt Oct 8, 2018
aa5079b
Remove '--num-classes' option, which could be derived from the label …
jkjung-avt Oct 8, 2018
5956941
Add code to deal with missing classes in the label map
jkjung-avt Oct 8, 2018
a9ab4b9
Reduce number of region proposals to 32 for all faster_rcnn models
jkjung-avt Oct 8, 2018
26561f0
Add support for 'rfcn_resnet101_egohands' model
jkjung-avt Oct 8, 2018
8076a7a
Add code to handle rfcn models
jkjung-avt Oct 8, 2018
300437a
Add code to handle rfcn models, as well as some minor optimizations
jkjung-avt Oct 8, 2018
af1fbf0
Merge pull request #9 from jkjung-avt/dev
jkjung-avt Oct 8, 2018
868cdf5
Attempt to merge NVIDIA's latest changes into my own repository
jkjung-avt Dec 13, 2018
513f4f0
Update tensorflow 'models' to a newer snapshot (hash: 6518c1c)
jkjung-avt Dec 13, 2018
a92bffa
Remove coco model configs (will download the latest files from Tensor…
jkjung-avt Dec 13, 2018
ea70984
Fix errors in the installation scripts
jkjung-avt Dec 13, 2018
9feea6b
Fix 'config_path' related code
jkjung-avt Dec 13, 2018
1820836
Fix class dictionary (cls_dict) indices: it is 1-based with the newer…
jkjung-avt Dec 14, 2018
b0d088c
Add 'force_2ndstage_cpu()' hack to make faster_rcnn_xxx and rfcn_xxx …
jkjung-avt Dec 14, 2018
ff6cbeb
Add a python3 related fix in object_detection/models/feature_map_gene…
jkjung-avt Jan 2, 2019
030f655
Merge pull request #10 from jkjung-avt/nvidia
jkjung-avt Jan 3, 2019
f7f2fa5
Add description about tensorflow version and the 'TF-TRT Revisited' post
jkjung-avt May 27, 2019
6d2294e
Do single-threading when reading from image or video files
jkjung-avt Jun 17, 2019
48a111d
Remove unused code
jkjung-avt Jun 20, 2019
321b349
Fix a minor bug in visualization.py
jkjung-avt Oct 25, 2019
0c811f4
Add support for 'nvarguscamerasrc'
jkjung-avt Nov 20, 2019
2f55c69
Highlight the TF-TRT Revisited blog post
jkjung-avt Nov 22, 2019
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
__pycache__
data/protoc
build/
dist/
tf_trt_models.egg-info/
data/
logs/
309 changes: 111 additions & 198 deletions README.md

Large diffs are not rendered by default.

199 changes: 199 additions & 0 deletions camera_tf_trt.py
@@ -0,0 +1,199 @@
"""camera_tf_trt.py

This is a Camera TensorFlow/TensorRT Object Detection sample code for
Jetson TX2 or TX1. This script captures and displays video from either
a video file, an image file, an IP CAM, a USB webcam, or the Tegra
onboard camera, and do real-time object detection with example TensorRT
optimized SSD models in NVIDIA's 'tf_trt_models' repository. Refer to
README.md inside this repository for more information.

This code is written and maintained by JK Jung <jkjung13@gmail.com>.
"""


import sys
import time
import logging
import argparse

import numpy as np
import cv2
import tensorflow as tf
# importing tensorflow.contrib.tensorrt registers the TRTEngineOp,
# which is needed to deserialize a TF-TRT optimized graph
import tensorflow.contrib.tensorrt as trt

from utils.camera import add_camera_args, Camera
from utils.od_utils import read_label_map, build_trt_pb, load_trt_pb, \
                           write_graph_tensorboard, detect
from utils.visualization import BBoxVisualization


# Constants
DEFAULT_MODEL = 'ssd_inception_v2_coco'
DEFAULT_LABELMAP = 'third_party/models/research/object_detection/' \
                   'data/mscoco_label_map.pbtxt'
WINDOW_NAME = 'CameraTFTRTDemo'
BBOX_COLOR = (0, 255, 0)  # green


def parse_args():
    """Parse input arguments."""
    desc = ('This script captures and displays live camera video, '
            'and does real-time object detection with a TF-TRT model '
            'on Jetson TX2/TX1/Nano')
    parser = argparse.ArgumentParser(description=desc)
    parser = add_camera_args(parser)
    parser.add_argument('--model', dest='model',
                        help='tf-trt object detection model '
                             '[{}]'.format(DEFAULT_MODEL),
                        default=DEFAULT_MODEL, type=str)
    parser.add_argument('--build', dest='do_build',
                        help='re-build TRT pb file (instead of using '
                             'the previously built version)',
                        action='store_true')
    parser.add_argument('--tensorboard', dest='do_tensorboard',
                        help='write optimized graph summary to TensorBoard',
                        action='store_true')
    parser.add_argument('--labelmap', dest='labelmap_file',
                        help='[{}]'.format(DEFAULT_LABELMAP),
                        default=DEFAULT_LABELMAP, type=str)
    parser.add_argument('--num-classes', dest='num_classes',
                        help='(deprecated and not used) number of object '
                             'classes', type=int)
    parser.add_argument('--confidence', dest='conf_th',
                        help='confidence threshold [0.3]',
                        default=0.3, type=float)
    args = parser.parse_args()
    return args


def open_display_window(width, height):
    """Open the cv2 window for displaying images with bounding boxes."""
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, width, height)
    cv2.moveWindow(WINDOW_NAME, 0, 0)
    cv2.setWindowTitle(WINDOW_NAME, 'Camera TFTRT Object Detection Demo '
                                    'for Jetson TX2/TX1')


def draw_help_and_fps(img, fps):
    """Draw help message and fps number at top-left corner of the image."""
    help_text = "'Esc' to Quit, 'H' for FPS & Help, 'F' for Fullscreen"
    font = cv2.FONT_HERSHEY_PLAIN
    line = cv2.LINE_AA

    fps_text = 'FPS: {:.1f}'.format(fps)
    # draw each text twice: a thick dark pass first, then a thin light
    # pass on top, so the text stays readable on any background
    cv2.putText(img, help_text, (11, 20), font, 1.0, (32, 32, 32), 4, line)
    cv2.putText(img, help_text, (10, 20), font, 1.0, (240, 240, 240), 1, line)
    cv2.putText(img, fps_text, (11, 50), font, 1.0, (32, 32, 32), 4, line)
    cv2.putText(img, fps_text, (10, 50), font, 1.0, (240, 240, 240), 1, line)
    return img


def set_full_screen(full_scrn):
    """Set display window to full screen or not."""
    prop = cv2.WINDOW_FULLSCREEN if full_scrn else cv2.WINDOW_NORMAL
    cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, prop)


def loop_and_detect(cam, tf_sess, conf_th, vis, od_type):
    """Loop, grab images from camera, and do object detection.

    # Arguments
        cam: the camera object (video source).
        tf_sess: TensorFlow/TensorRT session to run SSD object detection.
        conf_th: confidence/score threshold for object detection.
        vis: for visualization.
        od_type: type of object detector ('ssd' or 'faster_rcnn').
    """
    show_fps = True
    full_scrn = False
    fps = 0.0
    tic = time.time()
    while True:
        if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
            # Check to see if the user has closed the display window.
            # If yes, terminate the while loop.
            break

        img = cam.read()
        if img is not None:
            box, conf, cls = detect(img, tf_sess, conf_th, od_type=od_type)
            img = vis.draw_bboxes(img, box, conf, cls)
            if show_fps:
                img = draw_help_and_fps(img, fps)
            cv2.imshow(WINDOW_NAME, img)
            toc = time.time()
            curr_fps = 1.0 / (toc - tic)
            # calculate an exponentially decaying average of fps number
            fps = curr_fps if fps == 0.0 else (fps*0.9 + curr_fps*0.1)
            tic = toc

        key = cv2.waitKey(1)
        if key == 27:  # ESC key: quit program
            break
        elif key == ord('H') or key == ord('h'):  # Toggle help/fps
            show_fps = not show_fps
        elif key == ord('F') or key == ord('f'):  # Toggle fullscreen
            full_scrn = not full_scrn
            set_full_screen(full_scrn)


def main():
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    # Ask tensorflow logger not to propagate logs to parent (which causes
    # duplicated logging)
    logging.getLogger('tensorflow').propagate = False

    args = parse_args()
    logger.info('called with args: %s' % args)

    # build the class (index/name) dictionary from labelmap file
    logger.info('reading label map')
    cls_dict = read_label_map(args.labelmap_file)

    pb_path = './data/{}_trt.pb'.format(args.model)
    log_path = './logs/{}_trt'.format(args.model)
    if args.do_build:
        logger.info('building TRT graph and saving to pb: %s' % pb_path)
        build_trt_pb(args.model, pb_path)

    logger.info('opening camera device/file')
    cam = Camera(args)
    cam.open()
    if not cam.is_opened:
        sys.exit('Failed to open camera!')

    logger.info('loading TRT graph from pb: %s' % pb_path)
    trt_graph = load_trt_pb(pb_path)

    logger.info('starting up TensorFlow session')
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_sess = tf.Session(config=tf_config, graph=trt_graph)

    if args.do_tensorboard:
        logger.info('writing graph summary to TensorBoard')
        write_graph_tensorboard(tf_sess, log_path)

    logger.info('warming up the TRT graph with a dummy image')
    od_type = 'faster_rcnn' if 'faster_rcnn' in args.model else 'ssd'
    dummy_img = np.zeros((720, 1280, 3), dtype=np.uint8)
    _, _, _ = detect(dummy_img, tf_sess, conf_th=.3, od_type=od_type)

    cam.start()  # ask the camera to start grabbing images

    # grab image and do object detection (until stopped by user)
    logger.info('starting to loop and detect')
    vis = BBoxVisualization(cls_dict)
    open_display_window(cam.img_width, cam.img_height)
    loop_and_detect(cam, tf_sess, args.conf_th, vis, od_type=od_type)

    logger.info('cleaning up')
    cam.stop()  # terminate the sub-thread in camera
    tf_sess.close()
    cam.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
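The `--build` flag above delegates the actual TF-TRT conversion to `build_trt_pb()` in `utils/od_utils.py`, which is not included in this diff. Below is a minimal sketch of the TF 1.x contrib TF-TRT flow it presumably wraps (the output tensor names and tuning values are illustrative assumptions, not the repository's actual settings):

```python
# Sketch only: not the repository's implementation. Converts a frozen
# detection graph into a TF-TRT optimized graph using the TF 1.x
# contrib API that camera_tf_trt.py imports above.
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt


def build_trt_pb_sketch(frozen_pb_path, trt_pb_path):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    trt_graph_def = trt.create_inference_graph(
        input_graph_def=graph_def,
        outputs=['num_detections', 'detection_boxes',       # assumed output
                 'detection_scores', 'detection_classes'],  # tensor names
        max_batch_size=1,                  # cf. commit bff3967 on max_batch_size
        max_workspace_size_bytes=1 << 25,  # 32 MB; illustrative value
        precision_mode='FP16',             # Jetson TX2/TX1 run FP16 natively
        minimum_segment_size=50)
    with tf.gfile.GFile(trt_pb_path, 'wb') as f:
        f.write(trt_graph_def.SerializeToString())
```

Once built (for example with `python3 camera_tf_trt.py --model ssd_inception_v2_coco --build`), subsequent runs can drop `--build` and load the cached `./data/ssd_inception_v2_coco_trt.pb` directly, as `main()` does above.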
4 changes: 4 additions & 0 deletions data/egohands_label_map.pbtxt
@@ -0,0 +1,4 @@
item {
  id: 1
  name: 'hand'
}
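`read_label_map()` (imported by camera_tf_trt.py above) lives in `utils/od_utils.py` and is not part of this diff. A minimal sketch of the kind of parsing involved, including filling in missing class IDs (cf. commit 5956941), might look like the following; the regex approach and the `'CLS{}'` placeholder names are assumptions:

```python
# Minimal sketch (not the repository's actual implementation): build a
# {class_id: class_name} dict from a label map like the one above.
import re


def read_label_map_sketch(path_to_labels):
    with open(path_to_labels, 'r') as f:
        content = f.read()
    ids = [int(i) for i in re.findall(r'\bid:\s*(\d+)', content)]
    names = re.findall(r"\bname:\s*'([^']+)'", content)
    cls_dict = dict(zip(ids, names))
    # fill in any IDs missing from the label map so lookups never fail
    num_classes = max(cls_dict.keys()) + 1 if cls_dict else 0
    return {i: cls_dict.get(i, 'CLS{}'.format(i)) for i in range(num_classes)}


print(read_label_map_sketch('data/egohands_label_map.pbtxt'))
# {0: 'CLS0', 1: 'hand'} -- IDs are 1-based in newer label maps (commit 1820836)
```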
146 changes: 146 additions & 0 deletions data/faster_rcnn_inception_v2_egohands.config
@@ -0,0 +1,146 @@
# Faster R-CNN with Inception v2, configured for egohands dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  faster_rcnn {
    num_classes: 1
    image_resizer {
      #keep_aspect_ratio_resizer {
      #  min_dimension: 600
      #  max_dimension: 1024
      #}
      # Use fixed input image dimension, refer to:
      # https://github.com/NVIDIA-Jetson/tf_trt_models/issues/6#issuecomment-423098067
      fixed_shape_resizer {
        height: 576
        width: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_inception_v2'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    #first_stage_max_proposals: 300
    first_stage_max_proposals: 32
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        #max_detections_per_class: 100
        max_detections_per_class: 32
        #max_total_detections: 300
        max_total_detections: 32
      }
      score_converter: SOFTMAX
    }
    second_stage_batch_size: 32
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

train_config: {
  batch_size: 1
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0002
          schedule {
            step: 30000
            learning_rate: .00002
          }
          schedule {
            step: 48000
            learning_rate: .000002
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "faster_rcnn_inception_v2_coco_2018_01_28/model.ckpt"
  from_detection_checkpoint: true
  num_steps: 50000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "data/egohands_train.tfrecord"
  }
  label_map_path: "data/egohands_label_map.pbtxt"
}

eval_config: {
  num_examples: 500
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "data/egohands_val.tfrecord"
  }
  label_map_path: "data/egohands_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
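The commented-out values above (300 proposals, 100/300 max detections) are the stock COCO settings; commit a9ab4b9 cuts them to 32, which shrinks the Faster R-CNN second-stage workload and with it the inference latency on Jetson. After editing a config like this, one way to sanity-check it is to parse it with the vendored Object Detection API (a sketch, assuming `third_party/models/research` is on `PYTHONPATH` as the installation scripts arrange):

```python
# Sketch: verify the egohands pipeline config with the TF Object
# Detection API's config utilities.
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file(
    'data/faster_rcnn_inception_v2_egohands.config')
frcnn = configs['model'].faster_rcnn
print(frcnn.num_classes)                # 1 ('hand')
print(frcnn.first_stage_max_proposals)  # 32 (down from the default 300)
resizer = frcnn.image_resizer.fixed_shape_resizer
print(resizer.height, resizer.width)    # 576 1024
```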
9 changes: 9 additions & 0 deletions data/faster_rcnn_inception_v2_egohands/README.md
@@ -0,0 +1,9 @@
# faster_rcnn_inception_v2_egohands

Copy your own trained 'faster_rcnn_inception_v2_egohands' model checkpoint files into this directory:

```
model.ckpt-50000.data-00000-of-00001
model.ckpt-50000.index
model.ckpt-50000.meta
```
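
For example, if a training run (50000 steps, matching `num_steps` in the config above) left its checkpoints in a hypothetical `~/egohands_train/` directory, the copy could be `cp ~/egohands_train/model.ckpt-50000.* data/faster_rcnn_inception_v2_egohands/`.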