Python - TensorFlow Java API error: java.lang.IllegalStateException: Tensor is not a scalar
I am trying to load a model that I pretrained in Python into a Java project.
The problem:
Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
    at org.tensorflow.Tensor.scalarFloat(Native Method)
    at org.tensorflow.Tensor.floatValue(Tensor.java:279)
The code:
float[] arr = context.csvIntArr(context.getPlayer(playerId));
float[][] matrix = {arr};
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Tensor y = s.runner()
                .feed("input/input", Tensor.create(matrix))
                .fetch("out/predict")
                .run()
                .get(0);
        logger.info("a {}", y.floatValue());
    }
}
The Python code that trains and saves the model:
with tf.Session() as sess:
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, bucketlen], name="input")
    ......
    with tf.name_scope('out'):
        y = tf.tanh(tf.matmul(h, hw) + hb, name="predict")
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess, ['foo-tag'])
    ...... (after the training process)
    builder.save()
It seems the model and graph have loaded correctly, because
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Operation operation = g.operation("input/input");
        logger.info(operation.name());
    }
}
prints the operation name successfully.
The error message indicates that the output tensor isn't a float-valued scalar; it's a higher-dimensional tensor (a vector, a matrix, etc.).
You can learn the shape of the tensor using System.out.println(y.toString()) or using y.shape(). In your Python code, this corresponds to y.shape.
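To make the distinction concrete: floatValue() only works on a rank-0 tensor, that is, one whose shape array is empty. With the model in the question, feeding a [1, bucketlen] input presumably produces an output of shape [1, k], so even when k is 1 the result is a 1x1 matrix holding one value, not a scalar. The sketch below uses hypothetical helper names (ShapeInfo, isScalar, numElements) and plain Java only; it assumes you pass it the long[] returned by y.shape().

```java
import java.util.Arrays;

/** Hypothetical helper, not part of the TensorFlow API. */
public class ShapeInfo {

    // A scalar is a rank-0 tensor: its shape array has no dimensions.
    // This is the only case in which floatValue() succeeds.
    static boolean isScalar(long[] shape) {
        return shape.length == 0;
    }

    // Number of elements = product of all dimensions (1 for a scalar).
    static long numElements(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    public static void main(String[] args) {
        // [1, 1] has exactly one element but rank 2, so it is not a scalar.
        long[][] shapes = { {}, {1}, {1, 1}, {2, 3} };
        for (long[] s : shapes) {
            System.out.printf("shape %s: rank=%d, elements=%d, scalar=%b%n",
                    Arrays.toString(s), s.length, numElements(s), isScalar(s));
        }
    }
}
```

In other words, "one value" and "scalar" are not the same thing for a tensor; the rank (shape.length) decides which accessor you can use.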
For non-scalars, use y.copyTo() to copy the values into an array of floats (for a vector), an array of arrays of floats (for a matrix), and so on.
For example, something like:
System.out.println(y);
// If the above prints something like
//   "FLOAT tensor with shape [1]"
// you can get the values using:
float[] vector = y.copyTo(new float[1]);
// If the shape is [2, 3], you can get the values using:
float[][] matrix = y.copyTo(new float[2][3]);
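The destination passed to copyTo() must match the tensor's shape exactly, which is why the sizes above are hard-coded. If the shape is only known at run time, one option is to allocate the destination reflectively from the shape array. This is a sketch with hypothetical names (CopyTarget, newFloatArray) that uses only java.lang.reflect.Array from the standard library:

```java
import java.lang.reflect.Array;

/** Hypothetical helper, not part of the TensorFlow API. */
public class CopyTarget {

    // Allocates a float array whose nesting depth and sizes match the
    // given shape, e.g. {2, 3} -> new float[2][3].
    static Object newFloatArray(long[] shape) {
        if (shape.length == 0) {
            // Rank 0 is a scalar; floatValue() handles that case.
            throw new IllegalArgumentException("scalar shape: use floatValue() instead");
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            // Java array dimensions are ints; fail loudly on overflow.
            dims[i] = Math.toIntExact(shape[i]);
        }
        // Array.newInstance(float.class, 2, 3) creates a float[2][3].
        return Array.newInstance(float.class, dims);
    }

    public static void main(String[] args) {
        float[][] matrix = (float[][]) newFloatArray(new long[] {2, 3});
        System.out.println(matrix.length + " x " + matrix[0].length); // 2 x 3
    }
}
```

Assuming y is the fetched tensor, usage would look like Object dst = CopyTarget.newFloatArray(y.shape()); y.copyTo(dst); followed by a cast of dst to float[] or float[][] once you know the rank.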
See the Tensor javadoc for more information on floatValue() vs. copyTo() vs. writeTo().
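writeTo() is the alternative when you would rather not build nested arrays: it fills a FloatBuffer with all the values in one flat sequence. Assuming the row-major (C-order) layout TensorFlow uses, element [i][j] of a [rows, cols] tensor lands at offset i * cols + j. The sketch below computes that offset for any rank; RowMajor and flatIndex are hypothetical names, and the buffer in main is filled by hand as a stand-in for one populated by writeTo().

```java
import java.nio.FloatBuffer;

/** Hypothetical helper, not part of the TensorFlow API. */
public class RowMajor {

    // Flat offset of the element at idx in a row-major layout of the
    // given shape. For shape {2, 3}, element [i][j] is at i * 3 + j.
    static int flatIndex(long[] shape, int... idx) {
        if (idx.length != shape.length) {
            throw new IllegalArgumentException("index rank does not match shape rank");
        }
        long offset = 0;
        for (int d = 0; d < shape.length; d++) {
            if (idx[d] < 0 || idx[d] >= shape[d]) {
                throw new IndexOutOfBoundsException(
                        "index " + idx[d] + " out of range for dimension " + d);
            }
            offset = offset * shape[d] + idx[d];
        }
        return Math.toIntExact(offset);
    }

    public static void main(String[] args) {
        long[] shape = {2, 3};
        // Stand-in for a buffer filled by writeTo(): values 0..5, row by row.
        FloatBuffer buf = FloatBuffer.allocate(6);
        for (int k = 0; k < 6; k++) {
            buf.put(k, k);
        }
        System.out.println(buf.get(flatIndex(shape, 1, 2))); // 5.0
    }
}
```

With a real tensor you would allocate a FloatBuffer sized to its element count, call y.writeTo(buf), and read values with buf.get(flatIndex(y.shape(), i, j)); the absolute get(int) lookup ignores whatever position writeTo() leaves the buffer at.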
Hope that helps.