2019-05-06 02:58:26 +02:00
|
|
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
|
|
def load_graph(model_file):
    """Load a frozen TensorFlow GraphDef from disk into a brand-new Graph.

    Args:
      model_file: Path to a binary-serialized ``GraphDef`` protobuf (``.pb``).

    Returns:
      A new ``tf.Graph`` holding the imported definition. ``tf.import_graph_def``
      is called without a ``name`` argument, so imported nodes are addressed
      with an ``import/`` prefix (e.g. ``import/final_result``).
    """
    serialized_def = tf.GraphDef()
    with open(model_file, "rb") as model_fh:
        serialized_def.ParseFromString(model_fh.read())

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(serialized_def)
    return loaded
|
|
|
|
|
|
|
|
|
|
|
|
def read_tensor_from_image_file(file_name,
                                input_height=299,
                                input_width=299,
                                input_mean=0,
                                input_std=255):
    """Decode an image file and return it as a normalized, batched float array.

    The decoder is chosen from the file extension (case-insensitively):
    ``.png``, ``.gif`` and ``.bmp`` use their dedicated decoders and anything
    else is decoded as JPEG. The image is cast to float32, given a leading
    batch dimension, bilinearly resized to ``(input_height, input_width)``
    and normalized as ``(pixel - input_mean) / input_std``.

    Args:
      file_name: Path of the image to read.
      input_height: Target height after resizing.
      input_width: Target width after resizing.
      input_mean: Value subtracted from every pixel.
      input_std: Value each mean-shifted pixel is divided by.

    Returns:
      The NumPy array obtained by evaluating the normalized tensor, shaped
      ``(1, input_height, input_width, 3)``.
    """
    lowered = file_name.lower()

    # Build the preprocessing ops in a private graph: the previous version
    # added them to the default graph on every call, so calling this once per
    # image made that graph (and memory use) grow without bound.
    graph = tf.Graph()
    with graph.as_default():
        file_reader = tf.read_file(file_name, "file_reader")
        if lowered.endswith(".png"):
            image_reader = tf.image.decode_png(
                file_reader, channels=3, name="png_reader")
        elif lowered.endswith(".gif"):
            # decode_gif yields (num_frames, height, width, 3); keep only the
            # first frame so animated GIFs also produce a single 3-D image.
            image_reader = tf.image.decode_gif(
                file_reader, name="gif_reader")[0]
        elif lowered.endswith(".bmp"):
            # Force 3 channels so BMPs carrying an alpha channel match the
            # RGB input the other decoders produce.
            image_reader = tf.image.decode_bmp(
                file_reader, channels=3, name="bmp_reader")
        else:
            image_reader = tf.image.decode_jpeg(
                file_reader, channels=3, name="jpeg_reader")
        float_caster = tf.cast(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(
            dims_expander, [input_height, input_width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])

    # The context manager guarantees the session is closed (it previously leaked).
    with tf.Session(graph=graph) as sess:
        return sess.run(normalized)
|
|
|
|
|
|
|
|
|
|
|
|
def load_labels(label_file):
    """Read a label file into a list with one entry per line.

    Args:
      label_file: Path of a text file, opened through ``tf.gfile.GFile``.

    Returns:
      The file's lines in order, each with trailing whitespace (including
      the newline) stripped.
    """
    # Use a context manager so the file handle is closed (it previously leaked).
    with tf.gfile.GFile(label_file) as label_fh:
        return [line.rstrip() for line in label_fh.readlines()]
|
|
|
|
|
2019-06-09 23:18:35 +02:00
|
|
|
def classify(model_file="Model/graph.pb",
             label_file="Model/graph_labels.txt",
             input_height=299,
             input_width=299,
             input_mean=128,
             input_std=128,
             input_layer="Mul",  # stock Inception V3 graphs use "input"
             output_layer="final_result",  # stock: "InceptionV3/Predictions/Reshape_1"
             image_dir="Images/TestImages"):
    """Classify every image file in ``image_dir`` with a frozen graph.

    For each regular file in ``image_dir`` (visited in sorted order), the image
    is preprocessed, fed to ``input_layer``, and the best-scoring label is read
    from ``output_layer``. A line per image is printed as a side effect.

    Args:
      model_file: Path of the frozen GraphDef to load.
      label_file: Path of the labels file; line i names output index i.
      input_height, input_width: Size the images are resized to.
      input_mean, input_std: Normalization applied as ``(x - mean) / std``.
      input_layer: Name of the input op, without the ``import/`` prefix.
      output_layer: Name of the output op, without the ``import/`` prefix.
      image_dir: Directory whose files are classified; subdirectories are
        skipped.

    Returns:
      A list of ``(filename, label, certainty)`` tuples, one per file. The
      certainty is the output-layer score of the top label (a value in 0-1
      when that layer is a softmax).

    Raises:
      KeyError: If ``input_layer`` or ``output_layer`` is not in the graph.
    """
    graph = load_graph(model_file)

    # None of these depend on the image, so do them once rather than per file.
    labels = load_labels(label_file)
    input_operation = graph.get_operation_by_name("import/" + input_layer)
    output_operation = graph.get_operation_by_name("import/" + output_layer)

    files = []
    # One session for all images. It is deliberately not entered with `with`:
    # Session.__enter__ would make `graph` the default graph, and any
    # preprocessing ops built on the default graph would then be added to
    # the model graph.
    sess = tf.Session(graph=graph)
    try:
        for filename in sorted(os.listdir(image_dir)):
            path = os.path.join(image_dir, filename)
            if not os.path.isfile(path):
                continue
            t = read_tensor_from_image_file(
                path,
                input_height=input_height,
                input_width=input_width,
                input_mean=input_mean,
                input_std=input_std)
            results = np.squeeze(sess.run(output_operation.outputs[0], {
                input_operation.outputs[0]: t
            }))
            best = int(np.argmax(results))
            files.append((filename, labels[best], results[best]))
            print(f'{filename}: {labels[best]} with {results[best] * 100}% certainity')
    finally:
        sess.close()
    return files
|
|
|
|
|
2019-05-06 02:58:26 +02:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the value classify()
    # uses for the same parameter, so the defaults are declared only once here.
    # Arguments are passed through unconditionally: the previous
    # `if args.x:` checks silently ignored explicit zero values such as
    # `--input_mean 0`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--graph",
                        default="Model/graph.pb",
                        help="graph/model to be executed")
    parser.add_argument("--labels",
                        default="Model/graph_labels.txt",
                        help="name of file containing labels")
    parser.add_argument("--input_height", type=int, default=299,
                        help="input height")
    parser.add_argument("--input_width", type=int, default=299,
                        help="input width")
    # float (not int) so values like 127.5 are accepted; ints still parse.
    parser.add_argument("--input_mean", type=float, default=128,
                        help="input mean")
    parser.add_argument("--input_std", type=float, default=128,
                        help="input std")
    parser.add_argument("--input_layer",
                        default="Mul",
                        help="name of input layer")
    parser.add_argument("--output_layer",
                        default="final_result",
                        help="name of output layer")
    args = parser.parse_args()

    classify(model_file=args.graph,
             label_file=args.labels,
             input_height=args.input_height,
             input_width=args.input_width,
             input_mean=args.input_mean,
             input_std=args.input_std,
             input_layer=args.input_layer,
             output_layer=args.output_layer)
|
|
|
|
|