machine learning - Freezing model drops output accuracy


I have an image segmentation network designed to classify roads and obstacles. I want to freeze the model and serve it as an API. I used the default TensorFlow tool for freezing the model. After freezing, the output given by the network is off and inaccurate.

Here is one sample.

[Image: the input image]

[Image: output when tested using the checkpoint files]

[Image: output after freezing the model]

I have tried freezing with different versions of TensorFlow, but that has not helped. Since the network performs as expected when tested against the checkpoint, I think the issue is in the freeze-model script. The network uses batch normalisation. Could that be the reason for the drop? I ask because I saw a couple of linked issues of a similar nature. How can I avoid this?
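Batch normalisation is a plausible suspect because it has two modes. In training mode it normalises with the mean and variance of the current batch; in inference mode it uses stored moving averages. A minimal pure-Python sketch for one channel, assuming the usual formula y = gamma * (x - mean) / sqrt(var + eps) + beta. The helper name, the eps value and the numbers are illustrative, not taken from ENet or TensorFlow.

```python
import math


def batch_norm(values, is_training, moving_mean, moving_var,
               gamma=1.0, beta=0.0, eps=1e-3):
    """Normalise one channel's activations (hypothetical helper)."""
    if is_training:
        # Training mode: statistics come from the current batch itself.
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
    else:
        # Inference mode: statistics come from the stored moving averages.
        mean, var = moving_mean, moving_var
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


# Activations of one channel for a single image (batch_size=1).
pixels = [4.0, 6.0, 5.0, 7.0]
# Moving averages accumulated over training (made-up values).
train_out = batch_norm(pixels, True, moving_mean=2.0, moving_var=4.0)
infer_out = batch_norm(pixels, False, moving_mean=2.0, moving_var=4.0)
print([round(v, 2) for v in train_out])  # [-1.34, 0.45, -0.45, 1.34]
print([round(v, 2) for v in infer_out])  # [1.0, 2.0, 1.5, 2.5]
```

The same parameters give different activations purely because of the mode flag, so a graph built with is_training=True and one built with is_training=False can disagree even when they share a checkpoint.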

Here's the full network.

Prediction using the checkpoint files:

    import tensorflow as tf

    slim = tf.contrib.slim
    # images_list, preprocess, enet, enet_arg_scope, num_initial_blocks,
    # stage_two_repeat, skip_connections, checkpoint, grayscale_to_colour,
    # imsave and photo_dir are defined elsewhere in the project.

    with tf.Graph().as_default() as graph:
        # Queue-based input pipeline that reads the PNG files one by one.
        images_tensor = tf.train.string_input_producer(images_list, shuffle=False)
        reader = tf.WholeFileReader()
        key, image_tensor = reader.read(images_tensor)
        image = tf.image.decode_png(image_tensor, channels=3)
        image = preprocess(image)
        images = tf.train.batch([image], batch_size=1, allow_smaller_final_batch=True)

        # Create the model inference graph.
        # is_training=True here: batch norm uses per-batch statistics.
        with slim.arg_scope(enet_arg_scope()):
            logits, probabilities = enet(images,
                                         num_classes=4,
                                         batch_size=1,
                                         is_training=True,
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint)

        # Per-pixel class labels.
        predictions = tf.argmax(probabilities, -1)
        predictions = tf.cast(predictions, tf.float32)
        sv = tf.train.Supervisor(logdir=None, init_fn=restore_fn)
        with sv.managed_session() as sess:
            for i in xrange(int(len(images_list) / 1 + 1)):
                segmentations = sess.run(predictions)
                for j in xrange(segmentations.shape[0]):
                    converted_image = grayscale_to_colour(segmentations[j], i, j)
                    imsave(photo_dir + "/imagelabel_%05d_edges.png" % (i * 1 + j), converted_image)

Prediction using the frozen .pb file:

    import time

    import cv2
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.platform import gfile

    # img and grayscale_to_colour are defined elsewhere in the project.

    def predict():
        start = time.time()
        y_out = persistent_sess.run(y, feed_dict={x: x_in})
        end = time.time()
        print(end - start)
        return y_out

    with tf.Session() as sess:
        model_filename = "frozen_model_tf_version.pb"
        with gfile.FastGFile(model_filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def)
            g_in = tf.get_default_graph()

    # Input and output tensors of the imported (frozen) graph.
    x = g_in.get_tensor_by_name('import/batch:0')
    y = g_in.get_tensor_by_name('import/enet/output:0')

    persistent_sess = tf.Session(graph=g_in)
    x_in_unaltered = cv2.imread(img)
    x_in_unaltered = cv2.resize(x_in_unaltered, (480, 360), interpolation=cv2.INTER_CUBIC)
    x_in = np.expand_dims(x_in_unaltered.flatten(), axis=0)
    predictions = predict()
    print(np.unique(predictions, return_counts=True))
    out = np.array(predictions[0], dtype=np.float32)
    out = np.reshape(out, [360, 480])
    converted_image = grayscale_to_colour(out, x_in_unaltered)
    cv2.imwrite("out.png", converted_image)

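The problem here is related to is_training. Since you are using dropout and batch_norm, is_training should be set to False at prediction time. With is_training=True, batch normalisation normalises each batch with that batch's own mean and variance instead of the moving averages accumulated during training, and dropout keeps randomly zeroing activations; with is_training=False, both layers are deterministic and use only the learned values. Build the prediction graph (and the graph you freeze) with is_training=False and you can expect the same results.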
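The dropout half of that explanation can be sketched the same way. This is a minimal pure-Python version of inverted dropout, assuming a keep probability called keep_prob (the name slim.dropout uses for its parameter); the dropout function below is a hypothetical helper, not the TensorFlow op.

```python
import random


def dropout(values, keep_prob, is_training, rng):
    """Inverted dropout on a list of activations (hypothetical helper)."""
    if not is_training:
        # Inference: dropout is the identity, so the output is deterministic.
        return list(values)
    # Training: drop each unit with probability 1 - keep_prob and scale the
    # survivors by 1 / keep_prob so the expected value stays the same.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


activations = [1.0, 2.0, 3.0, 4.0, 5.0]
rng = random.Random(0)
train_a = dropout(activations, 0.5, True, rng)
train_b = dropout(activations, 0.5, True, rng)
infer_a = dropout(activations, 0.5, False, rng)
infer_b = dropout(activations, 0.5, False, rng)
print(train_a, train_b)  # random masks, usually different from each other
print(infer_a, infer_b)  # both identical to the input
```

Two runs in training mode see different masks, while inference mode returns the input unchanged every time.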

    # is_training=False: batch norm uses its moving averages, dropout is disabled.
    logits, probabilities = enet(images,
                                 num_classes=4,
                                 batch_size=1,
                                 is_training=False,
                                 reuse=None,
                                 num_initial_blocks=num_initial_blocks,
                                 stage_two_repeat=stage_two_repeat,
                                 skip_connections=skip_connections)
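With is_training=False, batch normalisation reads moving_mean and moving_variance, which hold useful values only if they were updated during training (they start at 0 and 1). A minimal sketch of the exponential moving average behind them, assuming the update form moving = decay * moving + (1 - decay) * batch_value. TF-Slim's batch_norm has a decay parameter, but the 0.9 used here is only for a short demo, and the batch means are made up.

```python
def update_moving_average(moving, batch_value, decay):
    """One training-step update of a batch-norm moving statistic."""
    # moving <- decay * moving + (1 - decay) * batch_value
    return decay * moving + (1.0 - decay) * batch_value


# Made-up per-batch means of one channel, all close to 5.0 (100 steps).
batch_means = [4.8, 5.1, 5.2, 4.9, 5.0] * 20
moving_mean = 0.0  # initial value before any update
for m in batch_means:
    moving_mean = update_moving_average(moving_mean, m, decay=0.9)
print(round(moving_mean, 2))  # 5.0
```

These accumulated values are what the inference-mode graph uses, and freezing turns them into constants together with the weights.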
