TensorFlow batch norm leading to an imbalance between training loss and validation loss?


The code snippet is below. I use queues to load the training data, and use feed_dict to load the validation images. Along the training process, the training loss and training accuracy come out right. However, in the validation phase, the validation loss and accuracy are weird: the validation loss is high and the validation accuracy is low no matter how many steps I run, as if the model were guessing at random. However, when I set the 'is_training' parameter to True instead of False in the function load_validate_img_data, the validation loss and accuracy come out right. Is there something wrong with my use of batch_norm?

    import numpy as np
    import cv2
    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    from tensorflow.python.ops import control_flow_ops
    from glob import glob

    # weight_decay, label_num, original_size, channels, resize_to, data_dir,
    # batch_size, sample_num, epoc_num and merged_val are defined elsewhere;
    # ut / face_train are my own modules, and inference() below is the same
    # function that is called as face_train.inference().

    def inference(inputs,
                  num_classes=1000,
                  is_training=True,
                  dropout_keep_prob=0.5,
                  reuse=None,
                  scope='alexnet'):
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            normalizer_fn=slim.batch_norm,
                            activation_fn=tf.nn.relu,
                            biases_initializer=tf.constant_initializer(0.1),
                            weights_regularizer=slim.l2_regularizer(weight_decay),
                            normalizer_params={'is_training': is_training,
                                               'decay': 0.95,
                                               'reuse': reuse,
                                               'scope': scope}):
            with slim.arg_scope([slim.conv2d], padding='SAME'):
                with slim.arg_scope([slim.max_pool2d], padding='VALID'):
                    with tf.variable_scope(scope, [inputs], reuse=reuse) as sc:
                        net = slim.conv2d(inputs, 32, [3, 3], 2, scope='conv1', padding='VALID')
                        net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
                        net = slim.conv2d(net, 64, [3, 3], scope='conv2')
                        net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
                        net = slim.conv2d(net, 128, [2, 2], scope='conv3')
                        net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
                        net = slim.conv2d(net, 256, [2, 2], scope='conv4')
                        net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
                        net = slim.conv2d(net, 512, [2, 2], scope='conv5')
                        net = slim.avg_pool2d(net, [2, 2], scope='pool5')
                        net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
                        net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                                          normalizer_fn=None, scope='fc7')
                        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
                        end_points = net
                        return net, end_points

    def get_softmax_loss(logits, labels, name='train'):
        one_hot_labels = slim.one_hot_encoding(labels, label_num)
        softmax_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels, logits=logits))
        vars = tf.trainable_variables()
        regularization_loss = tf.add_n([tf.nn.l2_loss(v) for v in vars]) * 0.0005
        total_loss = softmax_loss + regularization_loss
        return total_loss

    def get_train_op(loss):
        lr_in_use = tf.Variable(0.01, trainable=False)
        with tf.name_scope('lr_update'):
            lr_update = tf.assign(lr_in_use, tf.maximum(lr_in_use * 0.5, 0.000001))
        optimizer = tf.train.MomentumOptimizer(lr_in_use, 0.9)
        step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
        train_op = slim.learning.create_train_op(loss, optimizer, global_step=step)
        # Tie the batch-norm moving-average update ops to this loss tensor.
        loss_update = loss
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            loss_update = control_flow_ops.with_dependencies([updates], loss)
        return train_op, loss_update, lr_update

    def get_train_acc(logits, labels, name='train'):
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))
        return accuracy

    def load_validate_img_data():
        validate_img_root = '~/data/'
        img_roots = glob(validate_img_root + '*.bmp')
        validate_img = []
        validate_label = []
        read_count = 0
        for root in img_roots:
            if read_count == 400:
                break
            # The label is the file name without its extension.
            label_root = root.split('/')
            validate_label.append(label_root[-1][:-4])
            validate_img.append(cv2.imread(root))
            read_count += 1
        validate_img = np.array(validate_img).astype(np.float32)
        validate_label = np.array(validate_label).astype(np.int64)
        with tf.name_scope('validate_input'):
            input_imgs = tf.placeholder(tf.float32,
                                        shape=(100, original_size[0], original_size[1], channels),
                                        name='imgs')
            input_labels = tf.placeholder(tf.int64, shape=(100,), name='labels')
        transfer_input_imgs = ut._resize_crop_img(input_imgs, resize_to, resize_to,
                                                  process_type='validate')
        # Reuse the training variables, but run batch norm in inference mode.
        logits, out_data = face_train.inference(transfer_input_imgs, num_classes=label_num,
                                                is_training=False, reuse=True)
        validate_accuracy = get_train_acc(logits, input_labels, name='validate')
        validate_loss = get_softmax_loss(logits, input_labels, name='validate')
        return validate_img, validate_label, input_imgs, input_labels, validate_accuracy, validate_loss

    with tf.Graph().as_default():
        images, labels = ut._load_batch_t(data_dir, original_size, channels,
                                          batch_size, resize_to, resize_to)
        logits, _ = face_train.inference(images, num_classes=label_num)
        accuracy = get_train_acc(logits, labels)
        total_loss = get_softmax_loss(logits, labels)
        train_op, loss_update, lr_update = get_train_op(total_loss)
        (validate_img, validate_label, img_placeholer, label_placeholder,
         validate_accuracy, validate_loss) = load_validate_img_data()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            total_step = 0
            epoc_step = int(sample_num / batch_size)
            for epoc in range(epoc_num):
                for step in range(epoc_step):
                    _ = sess.run([train_op])
                    if total_step % 20 == 0:
                        loss, train_ac = sess.run([loss_update, accuracy])
                        print('epoc : %d, step : %d, train_loss : %.2f, train_acc: %.3f'
                              % (epoc, step, loss, train_ac))
                    if total_step % 200 == 0:
                        all_va_acc = 0
                        all_va_loss = 0
                        for i in range(4):
                            feed_dict = {img_placeholer: validate_img[i*100:(i+1)*100],
                                         label_placeholder: validate_label[i*100:(i+1)*100]}
                            # merged_val is a summary op defined elsewhere (not shown)
                            va_acc, va_loss, summary_val = sess.run(
                                [validate_accuracy, validate_loss, merged_val],
                                feed_dict=feed_dict)
                            all_va_acc += va_acc
                            all_va_loss += va_loss
                        print('validate_accuracy: %.2f,  validate_loss: %.2f'
                              % (all_va_acc / 4.0, all_va_loss / 4.0))
                    total_step += 1
            coord.request_stop()
            coord.join(threads)
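To make the setup concrete, the batch-norm pattern above reduces to something like this minimal sketch: one batch_norm layer shared (via reuse) between a training path and a validation path. The input shape here is invented for illustration:

    import tensorflow as tf
    import tensorflow.contrib.slim as slim

    # Illustrative 4-D input; shape made up for the example.
    x = tf.placeholder(tf.float32, [None, 8, 8, 3])

    # With is_training=True, batch_norm normalizes with the current batch's
    # mean/variance and adds moving-average update ops to
    # tf.GraphKeys.UPDATE_OPS; with is_training=False it normalizes with the
    # accumulated moving averages instead.
    net_train = slim.batch_norm(x, is_training=True, decay=0.95, scope='bn')
    net_valid = slim.batch_norm(x, is_training=False, decay=0.95, scope='bn', reuse=True)

    # The update ops that must run during training for the moving
    # averages used by the validation path to be meaningful.
    print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))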

During inference, batch norm uses the moving-average mean and the moving-average variance, so you need to set the parameter is_training to False:

    def inference(inputs,
                  num_classes=1000,
                  is_training=False,
                  dropout_keep_prob=0.5,
                  reuse=None,
                  scope='alexnet'):
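Note that this only works if the moving averages are actually maintained during training: slim.batch_norm places its moving-average update ops in the tf.GraphKeys.UPDATE_OPS collection, and those ops have to run with every training step. As far as I know, slim.learning.create_train_op already picks that collection up by default, which would make the explicit loss_update dependency in get_train_op redundant. Wired by hand, it would look roughly like this sketch (total_loss and optimizer stand in for your own objects):

    # Sketch only: `total_loss` and `optimizer` are assumed to exist,
    # as in the training script above.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Refresh the moving mean/variance on every training step, not only
        # on the iterations where some extra tensor happens to be fetched.
        train_op = optimizer.minimize(total_loss)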
