tensorflow - Saving tf.trainable_variables() using convert_variables_to_constants -


I have a Keras model that I would like to convert to a TensorFlow protobuf (e.g. saved_model.pb).

This model comes from transfer learning on the VGG-19 network, in which the head was cut off and fully-connected + softmax layers were trained while the rest of the VGG-19 network was frozen.

I can load the model in Keras, and then use keras.backend.get_session() to run the model in TensorFlow, generating the correct predictions:

frame = preprocess(cv2.imread("path/to/img.jpg"))
keras_model = keras.models.load_model("path/to/keras/model.h5")

keras_prediction = keras_model.predict(frame)
print(keras_prediction)

with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    output = sess.graph.get_tensor_by_name('softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')

    tf_prediction = sess.run(output, {input_tensor: frame})
    print(tf_prediction)  # matches keras_prediction

If I don't include the line tvars = tf.trainable_variables(), the tf_prediction variable is completely wrong and doesn't match the output of keras_prediction at all. In fact, all the values in the output (a single array with 4 probability values) are the same (~0.25, adding up to 1). This made me suspect that the weights for the head are just initialized to 0 if tf.trainable_variables() is not called first, which I confirmed after inspecting the model variables. In any case, calling tf.trainable_variables() causes the TensorFlow prediction to be correct.
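For reference, a minimal sketch (not the asker's model; names are made up) of how to confirm whether a graph's variables have actually been initialized, written against the TF 1.x API via tf.compat.v1 so it also runs under TF 2.x:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    # A stand-in for the trained head's weights.
    w = tf.get_variable('head/kernel', shape=[4, 4])
    with tf.Session() as sess:
        # Before any initializer runs, the variable is reported as missing.
        before = sess.run(tf.report_uninitialized_variables())
        sess.run(tf.global_variables_initializer())
        after = sess.run(tf.report_uninitialized_variables())
        print(len(before), len(after))  # 1 0
```

If the model's head variables show up in this report at prediction time, that explains the uniform ~0.25 outputs: the graph is running on default-initialized weights.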

The problem is that when I try to save this model, the variables from tf.trainable_variables() don't get saved to the .pb file:

with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['softmax'])
    graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)

What I am asking is, how can I save the Keras model to a TensorFlow protobuf with the tf.trainable_variables() intact?

Thanks so much!

So your approach of freezing the variables in the graph (converting them to constants) should work, but it isn't necessary and is trickier than other approaches. (More on this below.) If you want graph freezing for some reason (e.g. exporting to a mobile device), I'd need more details to help debug, as I'm not sure what implicit stuff Keras is doing behind the scenes with your graph. However, if you just want to save and load the graph later, I can explain how to do that (though no guarantees that whatever Keras is doing won't screw it up... happy to debug that too).
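For what it's worth, here is a sketch of the freeze-then-reload cycle on a toy graph, so you can see what it looks like when it works. The tensor names mirror the question ('input_1', 'softmax') but the graph itself is made up; it uses the TF 1.x API via tf.compat.v1 so it also runs under TF 2.x:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 2], name='input_1')
    w = tf.get_variable('w', initializer=tf.ones([2, 1]))
    y = tf.identity(tf.matmul(x, w), name='softmax')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Bake the current variable values into Const nodes.
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), ['softmax'])

# The frozen GraphDef now carries the weights as constants,
# so it can be imported and run with no variable-restore step.
with tf.Graph().as_default() as g:
    tf.import_graph_def(frozen, name='')
    with tf.Session(graph=g) as sess:
        out = sess.run('softmax:0', feed_dict={'input_1:0': [[1.0, 1.0]]})
        print(out)  # [[2.]]
```

If your Keras graph doesn't behave like this after freezing, that points at session/graph state Keras manages behind the scenes.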

So there are actually 2 formats at play here. One is GraphDef, which is used for checkpointing and does not contain metadata about inputs and outputs. The other is MetaGraphDef, which contains metadata as well as a graph def, the metadata being useful for prediction and for running a ModelServer (from tensorflow/serving).
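To make the distinction concrete, here is a tiny illustration (toy graph, TF 1.x API via tf.compat.v1) showing that a MetaGraphDef wraps a GraphDef plus metadata:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default() as g:
    tf.constant(1.0, name='c')
    graph_def = g.as_graph_def()                   # ops only, no metadata
    meta_graph_def = tf.train.export_meta_graph()  # GraphDef + metadata

print(type(graph_def).__name__)       # GraphDef
print(type(meta_graph_def).__name__)  # MetaGraphDef
# The MetaGraphDef embeds the same GraphDef:
print(len(meta_graph_def.graph_def.node) == len(graph_def.node))  # True
```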

In either case you need more than a call to graph_io.write_graph because the variables are usually stored outside the GraphDef.

There are wrapper libraries for both of these use cases. tf.train.Saver is primarily used for saving and restoring checkpoints.
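A minimal sketch of the checkpoint route with tf.train.Saver, on a toy variable (TF 1.x API via tf.compat.v1; the checkpoint path is made up):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), 'model')

with tf.Graph().as_default():
    v = tf.get_variable('v', initializer=tf.constant([1.0, 2.0]))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes the variable values outside the GraphDef
        # (model.index, model.data-*, plus a 'checkpoint' file).
        saver.save(sess, ckpt)

# Restoring into a fresh graph and session brings the values back.
with tf.Graph().as_default():
    v = tf.get_variable('v', shape=[2])
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        restored = sess.run(v)
        print(restored)  # [1. 2.]
```

Note that the checkpoint stores values only; you still need matching variable definitions (or the exported MetaGraphDef) to restore into.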

However, since you want prediction, I would suggest using tf.saved_model.builder.SavedModelBuilder to build a SavedModel binary. I've provided boilerplate for this below:

from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as default_sig_def

builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
    output = sess.graph.get_tensor_by_name('softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')
    sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
        {'input': input_tensor},
        {'output': output}
    )
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            default_sig_def: sig_def
        }
    )
builder.save()

After running this code you should have a mymodel/saved_model.pb file as well as a directory mymodel/variables/ with protobufs corresponding to the variable values.

Then to load the model again, simply use tf.saved_model.loader:

# Does Keras give you the ability to start from a fresh graph?
# If not, you'll need to do this in a separate program to avoid
# conflicts with the old default graph.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './mymodel'
    )
    # At this point the variables and graph structure are restored

    sig_def = meta_graph_def.signature_def[default_sig_def]
    print(sess.run(sig_def.outputs['output'].name,
                   feed_dict={sig_def.inputs['input'].name: frame}))

Obviously there's more efficient prediction available than this code through tensorflow/serving or Cloud ML Engine, but this should work. It's possible that Keras is doing something under the hood which will interfere with this process, and if so we'd like to hear about it (and I'd like to make sure that Keras users are able to freeze their graphs as well, so if you want to send me a gist with your full code, maybe I can find someone who knows Keras to help me debug).

Edit: You can find an end to end example of this here: https://github.com/googlecloudplatform/cloudml-samples/blob/master/census/keras/trainer/model.py#l85

