Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust NonMaxSuppression score_threshold in build_detection_graph for ssd networks #19

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion tf_trt_models/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import tarfile
import subprocess
import warnings

from google.protobuf import text_format

Expand Down Expand Up @@ -86,7 +87,7 @@ def download_detection_model(model, output_dir='.'):

return config_path, checkpoint_path

def build_detection_graph(config, checkpoint):
def build_detection_graph(config, checkpoint, score_threshold=0.3):
"""Build an object detection model from the TensorFlow model zoo.

This function creates an object detection model, sourced from the
Expand All @@ -112,6 +113,8 @@ def build_detection_graph(config, checkpoint):
:type config: string
:param checkpoint: path to the checkpoint files prefix containing trained model params
:type checkpoint: string
:score_threshold: NonMaxSuppression score_threshold (default 0.3)
:type score_threshold: float
:returns: the configured frozen graph representing object detection model
:rtype: a tensorflow GraphDef
"""
Expand All @@ -123,6 +126,13 @@ def build_detection_graph(config, checkpoint):
config = TrainEvalPipelineConfig()
text_format.Merge(config_str, config)

try:
old_score = config.model.ssd.post_processing.batch_non_max_suppression.score_threshold
config.model.ssd.post_processing.batch_non_max_suppression.score_threshold=score_threshold
warnings.warn("The score threshold of NonMaxSuppression was set from "+str(old_score)+" to "+str(score_threshold))
except AttributeError:
warnings.warn("The score threshold of NonMaxSuppression can not be reconfigured")
pass

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
Expand Down