python - Saving a model in TensorFlow not working under GPU?
Update: I've found that the code below works correctly when using tensorflow-cpu. The problem persists when using tensorflow-gpu. How can I make it work?
I cannot find the problem in my code. I'm trying to save variables and then reload them, and they don't appear to be loaded from the saved model.
I should note that they do load if I save and load in the same Python run (without the process ending and the testing script being run afterwards). The problem is that it doesn't work in this sequence: train mode -> save -> process ends -> run the script again with the testing flag -> the model is loaded without error, but the results look as if it wasn't.
Code:
Run #1
    import tensorflow as tf

    # creating the LSTM model...
    with tf.Session() as sess:
        saver = tf.train.Saver()
        # training...
        save_path = saver.save(sess, "./saved_models/model.ckpt")
        print("Model saved in file: %s" % save_path)
Run #2
    import tensorflow as tf

    # creating the same exact LSTM model...
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "./saved_models/model.ckpt")
        print("Model restored.")
        # testing...
If I run these 2 snippets back to back, I get the desired output: the model is trained to predict a trivial sequence, and it predicts that sequence during testing. If I run the 2 snippets separately, the model predicts a wrong sequence during testing.
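One way to narrow this down is to check whether the restored variables really hold the trained values. The sketch below is only a diagnostic under assumptions: it uses the TensorFlow 1.x API from the question, and the helper names `variable_fingerprints` and `changed_variables` are hypothetical, not part of TensorFlow. It reduces each variable to the sum of its values so the numbers from run #1 (after saving) and run #2 (after restoring) can be compared.

```python
def variable_fingerprints(sess):
    """Return {variable name: sum of its values} for all global variables.

    Assumes an open TF 1.x tf.Session in which the variables are already
    initialized (after training in run #1, after restore in run #2).
    """
    import tensorflow as tf  # imported lazily; the comparison helper below does not need it

    variables = tf.global_variables()
    values = sess.run(variables)  # list of NumPy arrays, one per variable
    return {v.op.name: float(val.sum()) for v, val in zip(variables, values)}


def changed_variables(saved, restored, tol=1e-6):
    """Names whose fingerprint differs between runs, or that exist in only one run."""
    names = set(saved) | set(restored)
    return sorted(
        name
        for name in names
        if name not in saved
        or name not in restored
        or abs(saved[name] - restored[name]) > tol
    )
```

In run #1 the fingerprints could be written to a file right after `saver.save(...)`, and in run #2 compared right after `saver.restore(...)`. An empty result from `changed_variables` would suggest the weights were restored and the difference comes from somewhere else (for example input preparation); a non-empty result points at the variables that were not restored as expected.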
Update: it was suggested that I try importing the metagraph, and it's not working either. Code:
Run #1
    import tensorflow as tf

    # creating the model...
    tf.add_to_collection('a', net.a)
    # adding nodes ...
    tf.add_to_collection('z', net.z)
    with tf.Session() as sess:
        saver = tf.train.Saver()
        # training...
        save_path = saver.save(sess, "./saved_models/my-model")
        print("Model saved in file: %s" % save_path)
Run #2
    import tensorflow as tf

    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta')
        new_saver.restore(sess, './saved_models/my-model')
        net.a = tf.get_collection('a')[0]
        # adding nodes ...
        net.z = tf.get_collection('z')[0]
        # testing...
The above code runs without errors, but the test-set result shows that the model is not in its post-training state (and again, if I run the 2 snippets in the same Python instance, it works correctly).
This should be trivial and I cannot get it to work. Any help is welcome. Specifically, I don't need to save the entire graph either, only the variables (some of them are inside the LSTM cell).
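For the "only the variables" case, a minimal sketch under the TensorFlow 1.x API follows. `tf.train.Saver` accepts a `var_list`, and `Saver.save` accepts `write_meta_graph=False` to skip the `.meta` file. The helper names are hypothetical, restricting the saver to `tf.trainable_variables()` is an assumption about what is needed at test time, and `tf.train.list_variables` is assumed to be available (it exists in later 1.x releases).

```python
def save_variables_only(sess, path):
    """Save just the trainable variables, without writing the .meta graph file."""
    import tensorflow as tf  # TF 1.x API, imported lazily

    saver = tf.train.Saver(var_list=tf.trainable_variables())
    return saver.save(sess, path, write_meta_graph=False)


def checkpoint_variable_names(path):
    """Names stored in the checkpoint at `path` (a checkpoint file prefix or directory)."""
    import tensorflow as tf

    return [name for name, _shape in tf.train.list_variables(path)]


def missing_from_checkpoint(graph_names, ckpt_names):
    """Variables the current graph expects that the checkpoint does not contain."""
    return sorted(set(graph_names) - set(ckpt_names))
```

In run #2, `graph_names` could be built as `[v.op.name for v in tf.global_variables()]`. Since the saver matches variables by name, this shows whether the variables inside the LSTM cell are stored under the names the rebuilt graph uses.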
I've encountered the same problem, and I guess you use tf.Variable(), right? Try changing it to tf.get_variable(). It worked for me :)
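A rough sketch of what that change can look like, assuming TensorFlow 1.4 or later (for `tf.AUTO_REUSE`). The layer, its scope name and its shapes are made up for illustration and are not the asker's model. The relevant difference is that `tf.Variable` always creates a new variable and makes its name unique (`w`, then `w_1`), while `tf.get_variable` inside a `tf.variable_scope` returns the existing variable under a stable name. The second helper is a hypothetical heuristic for spotting such suffixed duplicates in a list of variable names.

```python
import re


def build_output_layer(inputs, num_units, num_classes):
    """Hypothetical output layer whose variables always get the same names."""
    import tensorflow as tf  # TF 1.x API, imported lazily

    with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
        w = tf.get_variable("w", shape=[num_units, num_classes])
        b = tf.get_variable("b", shape=[num_classes],
                            initializer=tf.zeros_initializer())
    return tf.matmul(inputs, w) + b


def uniquified_duplicates(names):
    """Heuristic: names like 'output/w_1' whose un-suffixed form 'output/w' also exists."""
    existing = set(names)
    flagged = []
    for name in names:
        # Strip a trailing _<number> from every path component of the name.
        base = "/".join(re.sub(r"_\d+$", "", part) for part in name.split("/"))
        if base != name and base in existing:
            flagged.append(name)
    return sorted(flagged)
```

Running `uniquified_duplicates` on the variable names of the training graph and of the testing graph may show whether the model code created a second copy of some variables in one of the runs. It can give false positives for layers that are deliberately numbered.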