python - TensorFlow Java API error: java.lang.IllegalStateException: Tensor is not a scalar


I am trying to load a pretrained model (trained using Python) into a Java project.

The problem is:

    Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
        at org.tensorflow.Tensor.scalarFloat(Native Method)
        at org.tensorflow.Tensor.floatValue(Tensor.java:279)

Code:

    float[] arr = context.csvintarr(context.getplayer(playerid));
    float[][] matrix = {arr};
    try (Graph g = model.graph()) {
        try (Session s = model.session()) {
            Tensor y = s.runner()
                    .feed("input/input", Tensor.create(matrix))
                    .fetch("out/predict")
                    .run()
                    .get(0);
            // This is the call that throws: y is not a scalar.
            logger.info("a {}", y.floatValue());
        }
    }

The Python code that trains and saves the model:

    with tf.Session() as sess:
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, bucketlen], name="input")
        # ......
        with tf.name_scope('out'):
            y = tf.tanh(tf.matmul(h, hw) + hb, name="predict")
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess, ['foo-tag'])
        # ...... after the training process
        builder.save()

It seems to have loaded the model and the graph, because

    try (Graph g = model.graph()) {
        try (Session s = model.session()) {
            Operation operation = g.operation("input/input");
            logger.info(operation.name());
        }
    }

prints out the name successfully.

The error message indicates that the output tensor isn't a float-valued scalar; it's a higher-dimensional tensor (e.g. a vector or a matrix).
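To see why the fetched tensor is unlikely to be a scalar, it can help to follow the shapes. The sketch below is plain Java with no TensorFlow dependency. The class name `ShapeSketch`, the sizes, and the weight values are made up for illustration, and it assumes the elided layers keep the leading batch dimension. A `[1, n]` input multiplied by an `[n, k]` weight matrix gives a `[1, k]` result, which has rank 2 even when `k` is 1.

```java
public class ShapeSketch {
    // Naive matrix product: a is [rows, inner], b is [inner, cols].
    static float[][] matmul(float[][] a, float[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = b[0].length;
        float[][] out = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                float sum = 0f;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // One row, like the {arr} input fed in the question.
        float[][] h = {{1f, 2f, 3f, 4f}};
        // Hypothetical [4, 1] weights; the real hw is not shown in the question.
        float[][] hw = {{0.5f}, {0.5f}, {0.5f}, {0.5f}};
        float[][] y = matmul(h, hw);
        // The result is a [1, 1] matrix holding one value, not a scalar.
        System.out.println("shape [" + y.length + ", " + y[0].length + "]");
    }
}
```

Even with a single output unit, the value sits inside a 1 x 1 matrix, so it has to be read out of an array rather than through a scalar accessor.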

You can learn the shape of the tensor using System.out.println(y.toString()), or more specifically using y.shape(). In your Python code, this would correspond to y.shape.

For non-scalars, use y.copyTo to get an array of floats (for a vector), an array of arrays of floats (for a matrix), etc.

For example, something like:

    System.out.println(y);
    // If the above printed something like:
    // "FLOAT tensor with shape [1]"
    // then you can get the values using:
    float[] vector = y.copyTo(new float[1]);

    // If the shape was something like [2, 3]
    // then you can get the values using:
    float[][] matrix = y.copyTo(new float[2][3]);

See the Tensor javadoc for more information on floatValue() vs. copyTo vs. writeTo.

Hope that helps.
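If the shape is not known ahead of time, the destination array can be sized from y.shape(), which returns a long[]. Here is a minimal sketch for the rank-2 case, written in plain Java so it runs without TensorFlow. `ShapeAlloc` and `allocateMatrix` are hypothetical names, not part of the TensorFlow API, and the `{1, 3}` shape is only a stand-in.

```java
public class ShapeAlloc {
    // Builds a float[rows][cols] destination array from a rank-2 shape.
    static float[][] allocateMatrix(long[] shape) {
        if (shape.length != 2) {
            throw new IllegalArgumentException(
                    "expected rank 2, got rank " + shape.length);
        }
        // Math.toIntExact throws if a dimension does not fit in an int.
        int rows = Math.toIntExact(shape[0]);
        int cols = Math.toIntExact(shape[1]);
        return new float[rows][cols];
    }

    public static void main(String[] args) {
        long[] shape = {1, 3}; // stand-in for y.shape()
        float[][] dst = allocateMatrix(shape);
        System.out.println(dst.length + " x " + dst[0].length);
        // With a real tensor, the call would look something like:
        // float[][] values = y.copyTo(allocateMatrix(y.shape()));
    }
}
```

The rank check is there so that a vector or a rank-3 output fails with a readable message instead of a mismatched copy.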

