Python: Saving a model in TensorFlow not working under GPU?


Update: I've found that the code below works correctly when using tensorflow-cpu. The problem persists when using tensorflow-gpu. How can I make it work?

I cannot find the problem in my code. I'm trying to save my variables and then reload them, but the saved model doesn't appear to load.

I should note that it does load if I save and load in the same Python run (without the process ending and then running the testing script). The problem is that it doesn't work when I go train mode -> save -> process ends -> run the script again with the testing flag: the model is loaded without error, but the results are as if it wasn't.

Code:

Run #1:

    import tensorflow as tf

    # creating LSTM model...

    with tf.Session() as sess:
        saver = tf.train.Saver()

        # training...

        save_path = saver.save(sess, "./saved_models/model.ckpt")
        print("Model saved in file: %s" % save_path)

Run #2:

    import tensorflow as tf

    # creating the same exact LSTM model...

    with tf.Session() as sess:
        saver = tf.train.Saver()

        saver.restore(sess, "./saved_models/model.ckpt")
        print("Model restored.")

        # testing...

If I run these 2 snippets back to back, I get the desired output: the model is trained to predict a trivial sequence, and it predicts it correctly during testing. If I run the 2 snippets separately, the model predicts the wrong sequence during testing.
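A minimal way to check the save/restore round trip on its own is sketched below. It assumes TensorFlow 2.x and reaches the question's TF 1.x graph API through `tf.compat.v1`; `build_model` is a hypothetical stand-in for the LSTM, and `reset_default_graph()` only approximates a separate process (it rebuilds the graph from code, but it does not reproduce any CPU/GPU difference). The question's snippets don't show whether an initializer runs after `restore()`, so the sketch keeps it out of the restore path: re-running `global_variables_initializer()` after restoring would overwrite the loaded values.

```python
import os
import tempfile

import tensorflow as tf

# The question uses the TF 1.x graph API; in TF 2.x it is reachable via tf.compat.v1.
tf.compat.v1.disable_eager_execution()


def build_model():
    # Hypothetical stand-in for "creating the LSTM model": one named weight vector.
    return tf.compat.v1.get_variable(
        "w", shape=[3], initializer=tf.compat.v1.zeros_initializer())


def run_1_train_and_save(ckpt_path):
    tf.compat.v1.reset_default_graph()
    w = build_model()
    train_op = w.assign([1.0, 2.0, 3.0])  # stand-in for the training loop
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(train_op)
        return saver.save(sess, ckpt_path)


def run_2_restore(ckpt_path):
    # A fresh graph approximates a new process: the model is rebuilt from code.
    tf.compat.v1.reset_default_graph()
    w = build_model()
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        # No global_variables_initializer() here: running it after restore()
        # would overwrite the restored values with the initial zeros.
        saver.restore(sess, ckpt_path)
        return sess.run(w)


ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")
run_1_train_and_save(ckpt)
restored = run_2_restore(ckpt)
print(restored)  # [1. 2. 3.]
```

If a round trip like this restores correctly on the GPU machine as well, that would point toward how run #2 rebuilds the model rather than toward the checkpoint files themselves.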

Update: as suggested, I tried importing the MetaGraph, and it's not working either. Code:

Run #1:

    import tensorflow as tf

    # creating model...

    tf.add_to_collection('a', net.a)
    # adding nodes ...
    tf.add_to_collection('z', net.z)

    with tf.Session() as sess:
        saver = tf.train.Saver()
        # training...
        save_path = saver.save(sess, "./saved_models/my-model")
        print("Model saved in file: %s" % save_path)

Run #2:

    import tensorflow as tf

    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta')
        new_saver.restore(sess, './saved_models/my-model')

        net.a = tf.get_collection('a')[0]
        # adding nodes ...
        net.z = tf.get_collection('z')[0]

        # testing...

The above code runs without errors, but the test-set results look as if the model had not been trained (and again, if I run the 2 snippets in the same Python instance, it works correctly).
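For comparison, here is a small self-contained version of the MetaGraph approach, again assuming TensorFlow 2.x via `tf.compat.v1`; the tensors `x`, `w`, `y` and the collection names `"x"`/`"y"` are hypothetical. Run #2 rebuilds nothing in Python: it imports the graph from the `.meta` file, restores the weights, and fetches the needed tensors back from the collections.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "my-model")

# Run #1: build the graph, register tensors needed later in collections, "train", save.
tf.compat.v1.reset_default_graph()
x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
w = tf.compat.v1.get_variable("w", initializer=tf.constant(1.0))
y = tf.multiply(x, w, name="y")
tf.compat.v1.add_to_collection("x", x)
tf.compat.v1.add_to_collection("y", y)
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(w.assign(3.0))  # stand-in for training
    tf.compat.v1.train.Saver().save(sess, ckpt)  # also writes my-model.meta

# Run #2: the graph comes from the .meta file, not from Python model code.
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
    new_saver = tf.compat.v1.train.import_meta_graph(ckpt + ".meta")
    new_saver.restore(sess, ckpt)
    x2 = tf.compat.v1.get_collection("x")[0]
    y2 = tf.compat.v1.get_collection("y")[0]
    result = sess.run(y2, feed_dict={x2: [1.0, 2.0]})

print(result)  # [3. 6.]
```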

This should be trivial, and I cannot get it to work. Any help is welcome. Specifically, I don't need to save the entire graph either, just the variables (some of them inside an LSTM cell).
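Saving only some variables does not require the MetaGraph: `tf.train.Saver` accepts a `var_list`, and `save(..., write_meta_graph=False)` skips the `.meta` file. A sketch under the same assumptions (TensorFlow 2.x via `tf.compat.v1`), where the `lstm`/`other` scopes and the variable names are hypothetical; `list_variables` is used only to show which keys ended up in the checkpoint.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "weights.ckpt")


def build():
    # Hypothetical layout: "lstm" holds what should be saved, "other" does not.
    with tf.compat.v1.variable_scope("lstm"):
        kernel = tf.compat.v1.get_variable("kernel", initializer=tf.constant([0.0, 0.0]))
    with tf.compat.v1.variable_scope("other"):
        scratch = tf.compat.v1.get_variable("scratch", initializer=tf.constant(0.0))
    return kernel, scratch


def lstm_vars():
    # Select only the variables created under the "lstm" scope.
    return tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="lstm")


# Run #1: save only the "lstm" variables, without a .meta file.
tf.compat.v1.reset_default_graph()
kernel, scratch = build()
saver = tf.compat.v1.train.Saver(var_list=lstm_vars())
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run([kernel.assign([0.5, -0.5]), scratch.assign(9.0)])
    saver.save(sess, ckpt, write_meta_graph=False)

# Keys stored in the checkpoint.
saved_names = [name for name, _ in tf.compat.v1.train.list_variables(ckpt)]
print(saved_names)

# Run #2: rebuild the same graph, initialize everything, then restore the subset.
tf.compat.v1.reset_default_graph()
kernel, scratch = build()
saver = tf.compat.v1.train.Saver(var_list=lstm_vars())
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saver.restore(sess, ckpt)  # after the initializer, so it is not overwritten
    kernel_val, scratch_val = sess.run([kernel, scratch])

# kernel comes back as [0.5, -0.5]; scratch keeps its initial value 0.0.
print(kernel_val, scratch_val)
```

Variables left out of `var_list` are not in the checkpoint, so they still need initializing in run #2; that is why the initializer runs before `restore()` here rather than after.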

I've encountered the same problem. I guess you're using tf.Variable(), right? Try changing it to tf.get_variable(). It worked for me :)
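The answer doesn't say why `tf.get_variable()` helps. One plausible mechanism, sketched here as an illustration rather than a diagnosis of the GPU case: a `Saver` matches checkpoint entries to variables by name, and unnamed `tf.Variable`s get names from creation order (`Variable`, `Variable_1`, ...). If run #2 creates them in a different order, `restore()` can succeed while loading values into the wrong variables, as long as the shapes match. `tf.get_variable()` requires an explicit name, so the key does not depend on order. Same assumptions as above (TensorFlow 2.x via `tf.compat.v1`); `save_then_restore`, `unnamed` and `named` are hypothetical helpers.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()


def save_then_restore(make_a, make_b, swap_order):
    """Save a=1.0, b=2.0 in "run 1", rebuild in "run 2" (optionally creating
    the variables in the opposite order), restore, and return (a, b)."""
    ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

    tf.compat.v1.reset_default_graph()
    make_a(1.0)  # created first
    make_b(2.0)  # created second
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.train.Saver().save(sess, ckpt)

    tf.compat.v1.reset_default_graph()
    if swap_order:
        b = make_b(0.0)
        a = make_a(0.0)
    else:
        a = make_a(0.0)
        b = make_b(0.0)
    with tf.compat.v1.Session() as sess:
        tf.compat.v1.train.Saver().restore(sess, ckpt)
        return tuple(float(v) for v in sess.run([a, b]))


def unnamed(value):
    # Unnamed tf.Variable: its name comes from creation order.
    return tf.Variable(value)


def named(name):
    # get_variable needs an explicit name, so the checkpoint key is stable.
    return lambda value: tf.compat.v1.get_variable(name, initializer=tf.constant(value))


swapped = save_then_restore(unnamed, unnamed, swap_order=True)
stable = save_then_restore(named("a"), named("b"), swap_order=True)
print(swapped)  # (2.0, 1.0): restored without error, but values swapped
print(stable)   # (1.0, 2.0)
```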

