python 2.7 - Retraining Inception v3 model without reshape layer


I retrained the Inception v3 model on a custom dataset. After retraining, when I looked at the tensor graph, I found a layer named Reshape, followed by the fully connected layer that was added. I have to run this model on an embedded device using the Snapdragon Neural Processing Engine (SNPE), which doesn't support the Reshape layer when running on the DSP.
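
One way to confirm which nodes are involved is to scan the graph's node list by op type. The following is a minimal sketch, assuming only that the graph object exposes a `node` sequence whose items have `name`, `op` and `input` attributes (as a TensorFlow `GraphDef` does). `find_ops_by_type`, `Node`, `Graph` and the toy graph are hypothetical names used for illustration; they are not part of retrain.py or SNPE.

```python
from collections import namedtuple

# Stand-ins for GraphDef / NodeDef so the sketch runs without TensorFlow.
Node = namedtuple('Node', ['name', 'op', 'input'])
Graph = namedtuple('Graph', ['node'])


def find_ops_by_type(graph_def, op_type):
  """Returns (name, inputs) for every node whose op type equals op_type."""
  return [(n.name, list(n.input)) for n in graph_def.node if n.op == op_type]


# Toy graph: only 'pool_3/_reshape' comes from the question; every other
# node name here is made up for illustration.
toy_graph = Graph(node=[
    Node('pool_3', 'AvgPool', ['some_previous_layer']),
    Node('pool_3/_reshape/shape', 'Const', []),
    Node('pool_3/_reshape', 'Reshape', ['pool_3', 'pool_3/_reshape/shape']),
    Node('new_final_layer/MatMul', 'MatMul',
         ['pool_3/_reshape', 'new_final_layer/weights']),
])

reshape_nodes = find_ops_by_type(toy_graph, 'Reshape')
for name, inputs in reshape_nodes:
  print('%s <- %s' % (name, inputs))
```

The same function should work on a real `GraphDef` parsed from the retrained `.pb` file, since it only reads the `node`, `name`, `op` and `input` fields.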

Is there any way to retrain Inception v3 without adding the Reshape layer? Below is the part of the retrain code where the Reshape layer is added.

    import os

    import tensorflow as tf  # TensorFlow 1.x API (tf.logging)


    def create_model_info(architecture):
      """Given the name of a model architecture, returns information about it.

      There are different base image recognition pretrained models that can be
      retrained using transfer learning, and this function translates from the
      name of a model to the attributes needed to download and train with it.

      Args:
        architecture: Name of a model architecture.

      Returns:
        Dictionary of information about the model, or None if the name isn't
        recognized.

      Raises:
        ValueError: If architecture name is unknown.
      """
      architecture = architecture.lower()
      if architecture == 'inception_v3':
        # pylint: disable=line-too-long
        data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
        # pylint: enable=line-too-long
        # The bottleneck is read from the output of the Reshape op.
        bottleneck_tensor_name = 'pool_3/_reshape:0'
        bottleneck_tensor_size = 2048
        input_width = 299
        input_height = 299
        input_depth = 3
        resized_input_tensor_name = 'Mul:0'
        model_file_name = 'classify_image_graph_def.pb'
        input_mean = 128
        input_std = 128
      elif architecture.startswith('mobilenet_'):
        # Expected form: mobilenet_<version>_<size>[_quantized]
        parts = architecture.split('_')
        if len(parts) != 3 and len(parts) != 4:
          tf.logging.error("Couldn't understand architecture name '%s'",
                           architecture)
          return None
        version_string = parts[1]
        if (version_string != '1.0' and version_string != '0.75' and
            version_string != '0.50' and version_string != '0.25'):
          tf.logging.error(
              "The Mobilenet version should be '1.0', '0.75', '0.50', or "
              "'0.25', but found '%s' for architecture '%s'",
              version_string, architecture)
          return None
        size_string = parts[2]
        if (size_string != '224' and size_string != '192' and
            size_string != '160' and size_string != '128'):
          tf.logging.error(
              "The Mobilenet input size should be '224', '192', '160', or "
              "'128', but found '%s' for architecture '%s'",
              size_string, architecture)
          return None
        if len(parts) == 3:
          is_quantized = False
        else:
          if parts[3] != 'quantized':
            tf.logging.error(
                "Couldn't understand architecture suffix '%s' for '%s'",
                parts[3], architecture)
            return None
          is_quantized = True
        data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
        data_url += version_string + '_' + size_string + '_frozen.tgz'
        bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
        bottleneck_tensor_size = 1001
        input_width = int(size_string)
        input_height = int(size_string)
        input_depth = 3
        resized_input_tensor_name = 'input:0'
        if is_quantized:
          model_base_name = 'quantized_graph.pb'
        else:
          model_base_name = 'frozen_graph.pb'
        model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
        model_file_name = os.path.join(model_dir_name, model_base_name)
        input_mean = 127.5
        input_std = 127.5
      else:
        tf.logging.error("Couldn't understand architecture name '%s'",
                         architecture)
        raise ValueError('Unknown architecture', architecture)

      return {
          'data_url': data_url,
          'bottleneck_tensor_name': bottleneck_tensor_name,
          'bottleneck_tensor_size': bottleneck_tensor_size,
          'input_width': input_width,
          'input_height': input_height,
          'input_depth': input_depth,
          'resized_input_tensor_name': resized_input_tensor_name,
          'model_file_name': model_file_name,
          'input_mean': input_mean,
          'input_std': input_std,
      }
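
The `bottleneck_tensor_name` values in this function follow TensorFlow's `op_name:output_index` naming convention, so `'pool_3/_reshape:0'` refers to output 0 of the op named `pool_3/_reshape`. In other words, the script reads the bottleneck from the Reshape op itself. A small sketch of splitting such a name, where `split_tensor_name` is a hypothetical helper and not part of retrain.py:

```python
def split_tensor_name(tensor_name):
  """Splits 'op_name:index' into (op_name, index); index defaults to 0."""
  op_name, sep, index = tensor_name.rpartition(':')
  if not sep:
    # No ':' present, so the whole string is the op name.
    return tensor_name, 0
  return op_name, int(index)


bottleneck_op, bottleneck_output = split_tensor_name('pool_3/_reshape:0')
print('%s (output %d)' % (bottleneck_op, bottleneck_output))
```

The resulting op name can be matched against the node names in the graph to see exactly which node produces the bottleneck values.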

The complete code is available here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

