TensorFlow batch norm leading to an imbalance between training loss and validation loss
The code snippet below uses queues to load the training data and uses feed_dict to load the validation images. During training, the training loss and training accuracy behave as expected. However, in the validation phase the results are strange: the validation loss is high and the validation accuracy is low — essentially a random guess — no matter how many steps are run. Yet when I set the 'is_training' parameter to True instead of False in the function load_validate_img_data, the validation loss and accuracy become reasonable. Is there something wrong with my use of batch_norm?
def inference(inputs, num_classes=1000, is_training=True, dropout_keep_prob=0.5,
              reuse=None, scope='alexnet'):
    """Build a small AlexNet-style classifier with batch norm on every conv layer.

    Args:
        inputs: 4-D image tensor fed to the first conv layer
            (presumably NHWC — TODO confirm against the input pipeline).
        num_classes: number of output classes (width of the final 1x1 conv).
        is_training: controls both batch norm (moving stats vs. batch stats)
            and dropout; must be False at validation/inference time.
        dropout_keep_prob: keep probability for the dropout layer.
        reuse: passed to the variable scope and to batch norm so a second
            graph (e.g. the validation tower) can share the training weights.
        scope: name of the enclosing variable scope.

    Returns:
        (net, end_points) — the squeezed logits tensor, twice.
    """
    # NOTE(review): passing 'reuse' and 'scope' inside normalizer_params makes
    # every batch_norm layer resolve to the SAME variable scope, so all layers
    # may share one set of moving averages — likely the cause of the bad
    # validation stats. Confirm whether these two keys should be dropped.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                        activation_fn=tf.nn.relu,
                        biases_initializer=tf.constant_initializer(0.1),
                        weights_regularizer=slim.l2_regularizer(weight_decay),
                        normalizer_params={'is_training': is_training,
                                           'decay': 0.95,
                                           'reuse': reuse,
                                           'scope': scope}):
        with slim.arg_scope([slim.conv2d], padding='SAME'):
            with slim.arg_scope([slim.max_pool2d], padding='VALID'):
                with tf.variable_scope(scope, [inputs], reuse=reuse) as sc:
                    # conv1 overrides the SAME default with VALID and stride 2.
                    net = slim.conv2d(inputs, 32, [3, 3], 2, scope='conv1',
                                      padding='VALID')
                    net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
                    net = slim.conv2d(net, 64, [3, 3], scope='conv2')
                    net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
                    net = slim.conv2d(net, 128, [2, 2], scope='conv3')
                    net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
                    net = slim.conv2d(net, 256, [2, 2], scope='conv4')
                    net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
                    net = slim.conv2d(net, 512, [2, 2], scope='conv5')
                    net = slim.avg_pool2d(net, [2, 2], scope='pool5')
                    net = slim.dropout(net, dropout_keep_prob,
                                       is_training=is_training, scope='dropout6')
                    # Classifier head: plain 1x1 conv, no BN / no activation.
                    net = slim.conv2d(net, num_classes, [1, 1],
                                      activation_fn=None, normalizer_fn=None,
                                      scope='fc7')
                    # Drop the two spatial 1x1 dimensions -> (batch, num_classes).
                    net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
                    end_points = net
    return net, end_points


def get_softmax_loss(logits, labels, name='train'):
    """Softmax cross-entropy plus a hand-rolled L2 penalty over all
    trainable variables (weight 0.0005).

    Args:
        logits: (batch, label_num) unnormalized class scores.
        labels: integer class ids; one-hot encoded against the module-level
            `label_num`.
        name: unused; kept for interface compatibility with callers.

    Returns:
        Scalar total loss tensor.
    """
    one_hot_labels = slim.one_hot_encoding(labels, label_num)
    softmax_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels,
                                                logits=logits))
    # NOTE(review): the conv layers already carry slim.l2_regularizer via
    # arg_scope, so this term may double-count regularization — verify.
    trainable_vars = tf.trainable_variables()  # renamed from `vars` (shadowed builtin)
    regularization_loss = tf.add_n(
        [tf.nn.l2_loss(v) for v in trainable_vars]) * 0.0005
    total_loss = softmax_loss + regularization_loss
    return total_loss


def get_train_op(loss):
    """Create the momentum training op plus auxiliary update ops.

    Returns:
        train_op: slim train op (runs one optimization step).
        loss_update: `loss` with a control dependency on the graph's
            UPDATE_OPS collection (batch norm moving-average updates), so
            evaluating it also advances the BN statistics.
        lr_update: op that halves the learning rate, floored at 1e-6.
    """
    lr_in_use = tf.Variable(0.01, trainable=False)
    with tf.name_scope('lr_update'):
        lr_update = tf.assign(lr_in_use, tf.maximum(lr_in_use * 0.5, 0.000001))
    optimizer = tf.train.MomentumOptimizer(lr_in_use, 0.9)
    step = tf.get_variable("step", [],
                           initializer=tf.constant_initializer(0.0),
                           trainable=False)
    train_op = slim.learning.create_train_op(loss, optimizer, global_step=step)
    loss_update = loss
    # NOTE(review): loss_update is only evaluated every 20 steps in the
    # training loop below; if create_train_op does not already attach
    # UPDATE_OPS, BN moving averages are updated far too rarely — verify.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        updates = tf.group(*update_ops)
        loss_update = control_flow_ops.with_dependencies([updates], loss)
    return train_op, loss_update, lr_update


def get_train_acc(logits, labels, name='train'):
    """Fraction of the batch where argmax(logits) equals the int64 label."""
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(tf.arg_max(logits, 1), labels), tf.float32))
    return accuracy


def load_validate_img_data():
    """Load up to 400 .bmp validation images from disk and build the
    (weight-sharing) validation tower.

    The label is parsed from the file name (basename minus the 4-char
    extension). Placeholders are sized for batches of exactly 100 images.

    Returns:
        (validate_img, validate_label, input_imgs, input_labels,
         validate_accuracy, validate_loss)
    """
    validate_img_root = '~/data/'
    img_roots = glob(validate_img_root + '*.bmp')
    validate_img = []
    validate_label = []
    read_count = 0
    for root in img_roots:
        if read_count == 400:  # cap the validation set at 4 batches of 100
            break
        label_root = root.split('/')
        validate_label.append(label_root[-1][:-4])  # strip ".bmp"
        validate_img.append(cv2.imread(root))
        read_count += 1
    validate_img = np.array(validate_img).astype(np.float32)
    validate_label = np.array(validate_label).astype(np.int64)
    with tf.name_scope('validate_input'):
        input_imgs = tf.placeholder(
            tf.float32,
            shape=(100, original_size[0], original_size[1], channels),
            name='imgs')
        input_labels = tf.placeholder(tf.int64, shape=(100,), name='labels')
    transfer_input_imgs = ut._resize_crop_img(
        input_imgs, resize_to, resize_to, process_type='validate')
    # reuse=True shares the training tower's weights; is_training=False makes
    # batch norm use its moving averages and disables dropout.
    logits, out_data = face_train.inference(
        transfer_input_imgs, num_classes=label_num,
        is_training=False, reuse=True)
    validate_accuracy = get_train_acc(logits, input_labels, name='validate')
    validate_loss = get_softmax_loss(logits, input_labels, name='validate')
    return (validate_img, validate_label, input_imgs, input_labels,
            validate_accuracy, validate_loss)


with tf.Graph().as_default():
    images, labels = ut._load_batch_t(data_dir, original_size, channels,
                                      batch_size, resize_to, resize_to)
    # inference() returns (net, end_points); the original assigned the whole
    # tuple to `logits`, which would break get_train_acc/get_softmax_loss.
    logits, _ = face_train.inference(images, num_classes=label_num)
    accuracy = get_train_acc(logits, labels)
    total_loss = get_softmax_loss(logits, labels)
    train_op, loss_update, lr_update = get_train_op(total_loss)
    (validate_img, validate_label, img_placeholer, label_placeholder,
     validate_accuracy, validate_loss) = load_validate_img_data()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        total_step = 0
        epoc_step = int(sample_num / batch_size)
        for epoc in range(epoc_num):
            for step in range(epoc_step):
                _ = sess.run([train_op])
                if total_step % 20 == 0:
                    loss, train_ac = sess.run([loss_update, accuracy])
                    print('epoc : %d, step : %d, train_loss : %.2f, '
                          'train_acc: %.3f' % (epoc, step, loss, train_ac))
                if total_step % 200 == 0:
                    # Evaluate the 400 validation images in 4 batches of 100.
                    all_va_acc = 0
                    all_va_loss = 0
                    for i in range(4):
                        feed_dict = {
                            img_placeholer: validate_img[i * 100:(i + 1) * 100],
                            label_placeholder: validate_label[i * 100:(i + 1) * 100],
                        }
                        va_acc, va_loss, summary_val = sess.run(
                            [validate_accuracy, validate_loss, merged_val],
                            feed_dict=feed_dict)
                        all_va_acc += va_acc
                        all_va_loss += va_loss
                    print('validate_accuracy: %.2f, validate_loss: %.2f'
                          % (all_va_acc / 4.0, all_va_loss / 4.0))
                total_step += 1
        coord.request_stop()
        coord.join(threads)
During inference, batch norm uses the moving-average mean and the moving-average variance instead of the batch statistics, so you need to set the parameter is_training to False:
def inference(inputs, num_classes=1000, is_training=False, dropout_keep_prob=0.5, reuse=None, scope='alexnet'):
Comments
Post a Comment