def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:'+str(GPU_INDEX)):
            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            """ pointclouds_pl= (16,1024,3) label_pl = (16,) """
            is_training_pl = tf.placeholder(tf.bool, shape=())
            """ 注意全局step = batch参数以使其最小化;优化器增加batch参数在你每一次训练的时候 """ 
            batch = tf.get_variable('batch', [],
                                    initializer=tf.constant_initializer(0), trainable=False)
            bn_decay = get_bn_decay(batch)  # bn_decay: batch-normalization decay schedule
            tf.summary.scalar('bn_decay', bn_decay)

            # Get model and loss 
            pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
            """ 模型训练得到pred = (16,40),每个实例分为40个分数, end_points = (16,1024,3)为原始点云数据 """
            MODEL.get_loss(pred, labels_pl, end_points)
            """ 计算损失 """
            losses = tf.get_collection('losses') 
            """ 返回losses列表 """
            total_loss = tf.add_n(losses, name='total_loss')
            """ 将losses列表中所有值全部相加 """
            tf.summary.scalar('total_loss', total_loss)
            for l in losses + [total_loss]:
                tf.summary.scalar(l.op.name, l)

            correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
            """ 每一次训练后的预测分类正确的个数"""
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            """" 单个批次训练后的正确率"""
            tf.summary.scalar('accuracy', accuracy)

            print ("--- Get training operator")
            """ 获得训练操作 """
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            train_op = optimizer.minimize(total_loss, global_step=batch)
            """ 梯度优化最小化"""
            
            # Add ops to save and restore all the variables.
            saver = tf.train.Saver()

This is PointNet++'s training function. It builds the training graph: the point-cloud and label placeholders go in, then the flow is train the model → compute the loss → let the optimizer minimize it → save the model. After each batch it also computes the number of correctly classified instances and the classification accuracy.
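get_bn_decay and get_learning_rate are not shown above; in the PointNet++ repo both wrap tf.train.exponential_decay driven by the same batch counter. A minimal sketch, assuming the repo's hyperparameter constants (BASE_LEARNING_RATE, DECAY_STEP, DECAY_RATE, BN_INIT_DECAY, BN_DECAY_DECAY_STEP, BN_DECAY_DECAY_RATE, BN_DECAY_CLIP):

def get_learning_rate(batch):
    # Exponential decay, stepped by the number of samples seen so far.
    learning_rate = tf.train.exponential_decay(
        BASE_LEARNING_RATE,   # base learning rate
        batch * BATCH_SIZE,   # current position in the dataset
        DECAY_STEP,           # decay step
        DECAY_RATE,           # decay rate
        staircase=True)
    return tf.maximum(learning_rate, 0.00001)  # clip so it never reaches 0

def get_bn_decay(batch):
    # BN momentum decays over training; bn_decay = 1 - momentum, clipped.
    bn_momentum = tf.train.exponential_decay(
        BN_INIT_DECAY,
        batch * BATCH_SIZE,
        BN_DECAY_DECAY_STEP,
        BN_DECAY_DECAY_RATE,
        staircase=True)
    return tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)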

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        """ Create the session """
        # Add summary writers
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'), sess.graph)
       
        # Init variables
        init = tf.global_variables_initializer()
        sess.run(init)
        """ 变量初始化"""
ops = {'pointclouds_pl': pointclouds_pl,
               'labels_pl': labels_pl,
               'is_training_pl': is_training_pl,
               'pred': pred,
               'loss': total_loss,
               'train_op': train_op,
               'merged': merged,
               'step': batch,
               'end_points': end_points}

        best_acc = -1
        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()
             
            train_one_epoch(sess, ops, train_writer)
            eval_one_epoch(sess, ops, test_writer)

            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)

The ops dictionary is filled in and training begins: every epoch runs train_one_epoch followed by eval_one_epoch, and the model is saved to disk every 10 epochs.
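Inside train_one_epoch this ops dictionary is what gets fed to sess.run. Roughly, following the repo's pattern (batch_data and batch_label are the arrays returned by the dataset; a sketch, not the verbatim source):

feed_dict = {ops['pointclouds_pl']: batch_data,
             ops['labels_pl']: batch_label,
             ops['is_training_pl']: True}
summary, step, _, loss_val, pred_val = sess.run(
    [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
    feed_dict=feed_dict)
train_writer.add_summary(summary, step)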

Where the code differs from PointNet

def train_one_epoch(sess, ops, train_writer):
    ''' ...parameter and variable initialization elided... '''
    while TRAIN_DATASET.has_next_batch():
        batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)

""" The loop above does the data loading and preprocessing; the rest of the training steps follow. All of the data handling is written in modelnet_h5_dataset.py. """

train_one_epoch and eval_one_epoch differ little from their PointNet counterparts; the only real change is that the author moved the data handling into modelnet_h5_dataset.py, which defines a ModelNetH5Dataset class. A usage sketch follows.
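Construction and per-epoch consumption look roughly like this (a sketch; the train_files.txt path follows the repo's data layout, and the reset() placement at the end of an epoch is an assumption consistent with has_next_batch below):

TRAIN_DATASET = modelnet_h5_dataset.ModelNetH5Dataset(
    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'),
    batch_size=BATCH_SIZE, npoints=NUM_POINT, shuffle=True)

while TRAIN_DATASET.has_next_batch():
    batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)
    # feed the batch through ops as shown in train_one_epoch above
TRAIN_DATASET.reset()  # reshuffle the file order for the next epoch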

ModelNetH5Dataset():
Initialization of the class:
1. First, h5_files = getDataFiles(self.list_filename) reads the h5 filenames, one per line, from the train-files list.
2. reset() is defined to shuffle the order of the h5 files.

 def reset(self):
        ''' reset order of h5 files '''
        self.file_idxs = np.arange(0, len(self.h5_files))
        """ 创建索引并打乱索引 """
        if self.shuffle: np.random.shuffle(self.file_idxs)
        self.current_data = None
        self.current_label = None
        self.current_file_idx = 0
        self.batch_idx = 0

has_next_batch()

    def has_next_batch(self):
        # TODO: add backend thread to load data
        if (self.current_data is None) or (not self._has_next_batch_in_file()):
            if self.current_file_idx >= len(self.h5_files):
                return False
            self._load_data_file(self._get_data_filename())
            self.batch_idx = 0
            self.current_file_idx += 1
        return self._has_next_batch_in_file()
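The private helper it delegates to just checks whether another batch still starts inside the currently loaded file; a sketch:

    def _has_next_batch_in_file(self):
        # True while the next slice begins before the end of current_data
        return self.batch_idx * self.batch_size < self.current_data.shape[0]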

It first checks whether no data has been loaded yet (current_data is None) or the current file has been exhausted. If so, and if there are still h5 files left, it loads the next one:

self._load_data_file(self._get_data_filename())

self._get_data_filename() looks the filename up through the shuffled file index built in reset() (file-level shuffling).
self._load_data_file() uses h5py.File(h5_filename) to read out data and label, shuffles them together with one permutation (sample-level shuffling), and stores the results in current_data and current_label.
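A sketch of those two helpers, assuming the repo's conventions (the actual code goes through load_h5/shuffle_data utilities; the shuffle is inlined here for brevity):

    def _get_data_filename(self):
        # indirect through the shuffled file index built in reset()
        return self.h5_files[self.file_idxs[self.current_file_idx]]

    def _load_data_file(self, filename):
        f = h5py.File(filename, 'r')
        self.current_data = f['data'][:]
        self.current_label = np.squeeze(f['label'][:])
        self.batch_idx = 0
        if self.shuffle:
            # shuffle samples and labels with the same permutation
            idx = np.arange(self.current_data.shape[0])
            np.random.shuffle(idx)
            self.current_data = self.current_data[idx, ...]
            self.current_label = self.current_label[idx]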

.next_batch()

Extracts one batch-sized slice of point-cloud objects from the currently loaded h5 file for training.


    def next_batch(self, augment=False):
        ''' returned dimension may be smaller than self.batch_size '''
        start_idx = self.batch_idx * self.batch_size
        end_idx = min((self.batch_idx+1) * self.batch_size, self.current_data.shape[0])
        """从h5文件中分批次提取数据,设置初始索引和结束索引 """
        bsize = end_idx - start_idx
        batch_label = np.zeros((bsize), dtype=np.int32)
        data_batch = self.current_data[start_idx:end_idx, 0:self.npoints, :].copy()
        """ 将批次数据从全部数据中提取出来放进data_batch"""
        label_batch = self.current_label[start_idx:end_idx].copy()
        """ 提取出对应的label放进label_batch中"""
        self.batch_idx += 1
        if augment: data_batch = self._augment_batch_data(data_batch)
        """扩充数据,增加旋转和扰动的点云数据"""
        return data_batch, label_batch 
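The augment flag routes through self._augment_batch_data, which in the repo chains rotation, scaling, shifting, and jitter helpers from provider.py. A minimal standalone NumPy sketch of the core idea (rotation about the up axis plus clipped Gaussian jitter; the function name and the sigma/clip defaults are illustrative, not the repo's exact values):

import numpy as np

def augment_batch(batch_data, sigma=0.01, clip=0.05):
    """ Rotate each cloud about the Y (up) axis, then jitter every point. """
    B, N, C = batch_data.shape  # e.g. (16, 1024, 3)
    out = np.empty_like(batch_data)
    for b in range(B):
        angle = np.random.uniform(0, 2 * np.pi)
        c, s = np.cos(angle), np.sin(angle)
        rot = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
        out[b] = batch_data[b] @ rot  # (N, 3) x (3, 3) rotation
    jitter = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return out + jitter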

Original post: https://blog.csdn.net/CSDNcylinux/article/details/107149540