Table of Contents
  • Preface
  • I. What is semantic segmentation
  • II. UNet
    • 1. Basic principle
    • 2. mini_unet
    • 3. mobilenet_unet
    • 4. Data loading
  • References

    Preface

    Lately, feeling a bit lost while searching for a direction, I decided to study more of the models used across computer vision tasks. UNet struck me as an interesting model for semantic segmentation, so I set out to reimplement it.

    I. What is semantic segmentation

    A semantic segmentation task looks like the figure below:

    In short, semantic segmentation marks the different categories in an image with different colors, one color per category. It is commonly used for medical images and satellite imagery.

    So how do we assign a color to every pixel?

    The output of a semantic segmentation network is actually similar to that of an image classification network. A classifier outputs a one-dimensional one-hot vector over the classes, e.g. [0, 1, 0] for a three-class problem.

    A segmentation network instead outputs a three-dimensional feature map whose spatial size matches the original image and whose channel count equals the number of classes, as shown in the figure below (image from Zhihu):

    Each channel marks the pixels belonging to one class, i.e. where that class appears in the image. Taking the argmax across channels at every pixel and coloring each class differently composes the final image, with each color indicating a class location. Semantic segmentation is really just classification at a finer granularity: it predicts the class of every single pixel. That is all semantic segmentation is~
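
    To make the channel-wise argmax concrete, here is a minimal NumPy sketch; the 2x2 score map and the color palette are made up purely for illustration:

    import numpy as np

    # hypothetical network output: an (H, W, C) score map with H = W = 2 and C = 3 classes
    scores = np.array([[[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]],
                       [[0.2, 0.2, 0.6], [0.1, 0.6, 0.3]]])

    # argmax over the channel axis gives the class index of every pixel
    class_map = np.argmax(scores, axis=-1)   # shape (2, 2)
    print(class_map)                         # [[1 0]
                                             #  [2 1]]

    # color each class with its own RGB value to render the segmentation map
    palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)
    color_map = palette[class_map]           # shape (2, 2, 3)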

    Now let's reimplement the UNet model.

    II. UNet

    1. Basic principle

    What is UNet? Its network architecture is shown in the figure below:

    The whole network is shaped like a "U" and can be split into two parts. The red box in the figure is the feature-extraction (encoder) part: like any other convolutional network, it extracts image features with stacked convolutions and compresses the feature maps with pooling. The blue box is the image-restoration (decoder) part (calling it that may not be strictly standard, but the idea should be clear): it restores the compressed feature maps through upsampling and convolution. The encoder can be any strong backbone, e.g. ResNet50 or VGG.

    Note: since ResNet50 and VGG are rather large, this article uses MobileNet as the backbone feature extractor. To make UNet easier to understand, we will also build a small mini_unet from scratch. For convenience, the reimplementation upsamples the compressed feature maps back to the same size as the input.

    GitHub repo: the upload keeps failing,

    so the code is on Gitee for now: https://gitee.com/boss-jian/unet

    2. mini_unet

    mini_unet is built to help you understand the workflow of a semantic segmentation network; it is not a strong model for real segmentation tasks. Let's look at the implementation:

    from keras.layers import Input, Conv2D, Dropout, MaxPooling2D, Concatenate, UpSampling2D
    from keras.models import Model

    def unet_mini(n_classes=21, input_shape=(224, 224, 3)):

        img_input = Input(shape=input_shape)

        #------------------------------------------------------
        # encoder
        # 224,224,3 -> 112,112,32
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
        conv1 = Dropout(0.2)(conv1)
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
        pool1 = MaxPooling2D((2, 2), strides=2)(conv1)

        # 112,112,32 -> 56,56,64
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
        conv2 = Dropout(0.2)(conv2)
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
        pool2 = MaxPooling2D((2, 2), strides=2)(conv2)

        # 56,56,64 -> 56,56,128
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
        conv3 = Dropout(0.2)(conv3)
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)

        #-------------------------------------------------
        # decoder
        # 56,56,128 -> 112,112,128
        up1 = UpSampling2D(2)(conv3)
        # 112,112,128 + 112,112,64 (skip from conv2) -> 112,112,192
        up1 = Concatenate(axis=-1)([up1, conv2])
        # 112,112,192 -> 112,112,64
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(up1)
        conv4 = Dropout(0.2)(conv4)
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv4)

        # 112,112,64 -> 224,224,64
        up2 = UpSampling2D(2)(conv4)
        # 224,224,64 + 224,224,32 (skip from conv1) -> 224,224,96
        up2 = Concatenate(axis=-1)([up2, conv1])
        # 224,224,96 -> 224,224,32
        conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(up2)
        conv5 = Dropout(0.2)(conv5)
        conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv5)

        # 1x1 convolution maps the channels to the number of classes
        o = Conv2D(n_classes, 1, padding='same')(conv5)

        return Model(img_input, o, name="unet_mini")

    if __name__ == "__main__":
        model = unet_mini()
        model.summary()
    

    Through its encoder, mini_unet compresses the 224x224x3 image down to a 56x56x128 feature map, and the decoder then upsamples it back to a 224x224x32 feature map. Finally the convolution:

    o = Conv2D(n_classes, 1, padding='same')(conv5)

    adjusts the number of channels in the feature map to match the number of classes.
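
    As a quick sanity check, you can build the model and inspect its output shape. A minimal sketch, assuming the code above is saved as unet_mini.py:

    from unet_mini import unet_mini

    model = unet_mini(n_classes=21, input_shape=(224, 224, 3))
    # one score per class at every pixel: (None, 224, 224, 21)
    print(model.output_shape)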

    3. mobilenet_unet

    mobilenet_unet uses MobileNet as the backbone feature extractor and loads pretrained weights to strengthen feature extraction. The decoder part is the same as above. Here is the mobilenet_unet network:

    from keras.models import *
    from keras.layers import *
    import keras.backend as K

    IMAGE_ORDERING = "channels_last"  # channels last

    def relu6(x):
        return K.relu(x, max_value=6)


    def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):

        channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
        filters = int(filters * alpha)
        x = ZeroPadding2D(padding=(1, 1), name='conv1_pad',
                          data_format=IMAGE_ORDERING)(inputs)
        x = Conv2D(filters, kernel, data_format=IMAGE_ORDERING,
                   padding='valid',
                   use_bias=False,
                   strides=strides,
                   name='conv1')(x)
        x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
        return Activation(relu6, name='conv1_relu')(x)


    def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
                              depth_multiplier=1, strides=(1, 1), block_id=1):

        channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
        pointwise_conv_filters = int(pointwise_conv_filters * alpha)

        x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING,
                          name='conv_pad_%d' % block_id)(inputs)
        x = DepthwiseConv2D((3, 3), data_format=IMAGE_ORDERING,
                            padding='valid',
                            depth_multiplier=depth_multiplier,
                            strides=strides,
                            use_bias=False,
                            name='conv_dw_%d' % block_id)(x)
        x = BatchNormalization(
            axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
        x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)

        x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING,
                   padding='same',
                   use_bias=False,
                   strides=(1, 1),
                   name='conv_pw_%d' % block_id)(x)
        x = BatchNormalization(axis=channel_axis,
                               name='conv_pw_%d_bn' % block_id)(x)
        return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)

    def get_mobilenet_encoder(input_shape=(224, 224, 3), weights_path=""):

        # input size must be a multiple of 32
        assert input_shape[0] % 32 == 0
        assert input_shape[1] % 32 == 0

        alpha = 1.0
        depth_multiplier = 1

        img_input = Input(shape=input_shape)
        # (None, 224, 224, 3) -> (None, 112, 112, 64)
        x = _conv_block(img_input, 32, alpha, strides=(2, 2))
        x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)
        f1 = x

        # (None, 112, 112, 64) -> (None, 56, 56, 128)
        x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                                  strides=(2, 2), block_id=2)
        x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)
        f2 = x
        # (None, 56, 56, 128) -> (None, 28, 28, 256)
        x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                                  strides=(2, 2), block_id=4)
        x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
        f3 = x
        # (None, 28, 28, 256) -> (None, 14, 14, 512)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                                  strides=(2, 2), block_id=6)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
        f4 = x
        # (None, 14, 14, 512) -> (None, 7, 7, 1024)
        x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                                  strides=(2, 2), block_id=12)
        x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
        f5 = x
        # load pretrained weights
        if weights_path != "":
            Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)
        # f1: (None, 112, 112, 64)
        # f2: (None, 56, 56, 128)
        # f3: (None, 28, 28, 256)
        # f4: (None, 14, 14, 512)
        # f5: (None, 7, 7, 1024)
        return img_input, [f1, f2, f3, f4, f5]


    def mobilenet_unet(num_classes=2, input_shape=(224, 224, 3)):

        # encoder
        img_input, levels = get_mobilenet_encoder(input_shape=input_shape,
                                                  weights_path="model_data/mobilenet_1_0_224_tf_no_top.h5")

        [f1, f2, f3, f4, f5] = levels

        # f1: (None, 112, 112, 64)
        # f2: (None, 56, 56, 128)
        # f3: (None, 28, 28, 256)
        # f4: (None, 14, 14, 512)
        # f5: (None, 7, 7, 1024)

        # decoder
        # (None, 14, 14, 512) -> (None, 14, 14, 512)
        o = f4
        o = ZeroPadding2D()(o)
        o = Conv2D(512, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
        o = BatchNormalization()(o)

        # (None, 14, 14, 512) -> (None, 28, 28, 256)
        o = UpSampling2D(2)(o)
        o = Concatenate(axis=-1)([o, f3])
        o = ZeroPadding2D()(o)
        o = Conv2D(256, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
        o = BatchNormalization()(o)
        # (None, 28, 28, 256) -> (None, 56, 56, 128)
        o = UpSampling2D(2)(o)
        o = Concatenate(axis=-1)([o, f2])
        o = ZeroPadding2D()(o)
        o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
        o = BatchNormalization()(o)
        # (None, 56, 56, 128) -> (None, 112, 112, 128)
        o = UpSampling2D(2)(o)
        o = Concatenate(axis=-1)([o, f1])
        o = ZeroPadding2D()(o)
        o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
        o = BatchNormalization()(o)

        # upsample once more so the output matches the input size
        # (None, 112, 112, 128) -> (None, 224, 224, 64)
        o = UpSampling2D(2)(o)
        o = ZeroPadding2D()(o)
        o = Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
        o = BatchNormalization()(o)

        # (None, 224, 224, 64) -> (None, 224, 224, num_classes)
        o = Conv2D(num_classes, (3, 3), padding='same',
                   data_format=IMAGE_ORDERING)(o)
        # softmax over the channel axis so the cross-entropy loss
        # used during training receives probabilities
        o = Activation('softmax')(o)

        return Model(img_input, o)

    if __name__ == "__main__":
        mobilenet_unet(input_shape=(512, 512, 3)).summary()
    
    
    

    The feature-map size changes and the meaning of each step are annotated in the code comments. Read through them carefully.

    4. Data loading

    import math
    import os
    from random import shuffle

    import cv2
    import keras
    import numpy as np
    from PIL import Image

    #-------------------------------
    # convert an image to RGB
    #------------------------------
    def cvtcolor(image):
        if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
            return image
        else:
            image = image.convert('RGB')
            return image

    #-------------------------------
    # normalize pixels to [-1, 1]
    #------------------------------
    def preprocess_input(image):
        image = image / 127.5 - 1
        return image

    #---------------------------------------------------
    #   resize the input image with letterboxing
    #---------------------------------------------------
    def resize_image(image, size):
        iw, ih  = image.size
        w, h    = size

        scale   = min(w/iw, h/ih)
        nw      = int(iw*scale)
        nh      = int(ih*scale)

        image   = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w-nw)//2, (h-nh)//2))

        return new_image, nw, nh


    class UnetDataset(keras.utils.Sequence):
        def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, dataset_path):
            self.annotation_lines   = annotation_lines
            self.length             = len(self.annotation_lines)
            self.input_shape        = input_shape
            self.batch_size         = batch_size
            self.num_classes        = num_classes
            self.train              = train
            self.dataset_path       = dataset_path

        def __len__(self):
            return math.ceil(len(self.annotation_lines) / float(self.batch_size))

        def __getitem__(self, index):
            # images and labels
            images  = []
            targets = []
            # read one batch
            for i in range(index*self.batch_size, (index+1)*self.batch_size):
                # wrap around when i goes out of range
                i = i % self.length
                name = self.annotation_lines[i].split()[0]
                # read from disk: jpg is the image, png is the label
                jpg = Image.open(os.path.join(os.path.join(self.dataset_path, 'images'), name + '.png'))
                png = Image.open(os.path.join(os.path.join(self.dataset_path, 'labels'), name + '.png'))

                #-------------------
                # augmentation and normalization
                #-------------------
                jpg, png = self.get_random_data(jpg, png, self.input_shape, random=self.train)
                jpg = preprocess_input(np.array(jpg, np.float64))
                png = np.array(png)

                #-----------------------------------
                # the medical images outline cell edges,
                # so pixels below 127.5 are treated as
                # target pixels
                #------------------------------------
                seg_labels = np.zeros_like(png)
                seg_labels[png <= 127.5] = 1
                #--------------------------------
                # convert to one-hot labels
                #--------------------------------
                seg_labels  = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
                seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))

                images.append(jpg)
                targets.append(seg_labels)

            images  = np.array(images)
            targets = np.array(targets)
            return images, targets

        def rand(self, a=0, b=1):
            return np.random.rand() * (b - a) + a

        def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
            image = cvtcolor(image)
            label = Image.fromarray(np.array(label))
            h, w = input_shape

            if not random:
                iw, ih  = image.size
                scale   = min(w/iw, h/ih)
                nw      = int(iw*scale)
                nh      = int(ih*scale)

                image       = image.resize((nw, nh), Image.BICUBIC)
                new_image   = Image.new('RGB', (w, h), (128, 128, 128))
                new_image.paste(image, ((w-nw)//2, (h-nh)//2))

                label       = label.resize((nw, nh), Image.NEAREST)
                new_label   = Image.new('L', (w, h), 0)
                new_label.paste(label, ((w-nw)//2, (h-nh)//2))
                return new_image, new_label

            # resize the image with random aspect-ratio jitter
            rand_jit1 = self.rand(1-jitter, 1+jitter)
            rand_jit2 = self.rand(1-jitter, 1+jitter)
            new_ar = w/h * rand_jit1/rand_jit2

            scale = self.rand(0.25, 2)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)

            image = image.resize((nw, nh), Image.BICUBIC)
            label = label.resize((nw, nh), Image.NEAREST)

            # random horizontal flip
            flip = self.rand() < .5
            if flip:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

            # place the image on a gray canvas at a random offset
            dx = int(self.rand(0, w-nw))
            dy = int(self.rand(0, h-nh))
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_label = Image.new('L', (w, h), 0)
            new_image.paste(image, (dx, dy))
            new_label.paste(label, (dx, dy))
            image = new_image
            label = new_label

            # random color jitter in HSV space
            hue = self.rand(-hue, hue)
            sat = self.rand(1, sat) if self.rand() < .5 else 1/self.rand(1, sat)
            val = self.rand(1, val) if self.rand() < .5 else 1/self.rand(1, val)
            x = cv2.cvtColor(np.array(image, np.float32)/255, cv2.COLOR_RGB2HSV)
            # OpenCV float HSV keeps hue in [0, 360], so wrap by 360
            x[..., 0] += hue*360
            x[..., 0][x[..., 0] > 360] -= 360
            x[..., 0][x[..., 0] < 0] += 360
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x[:, :, 0] > 360, 0] = 360
            x[:, :, 1:][x[:, :, 1:] > 1] = 1
            x[x < 0] = 0
            image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
            return image_data, label

        def on_epoch_end(self):
            shuffle(self.annotation_lines)
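
    One detail worth noting is the num_classes + 1 in the one-hot step: following the reference repos, an extra channel is reserved, and the loss in the training code below drops it again with y_true[..., :-1]. A minimal NumPy sketch of the encoding, using a made-up 2x2 label map and num_classes = 2:

    import numpy as np

    num_classes = 2
    # toy 2x2 label map: 1 marks target pixels, 0 the background
    seg_labels = np.array([[1, 0],
                           [0, 1]])

    # np.eye(n)[indices] turns integer labels into one-hot vectors
    one_hot = np.eye(num_classes + 1)[seg_labels.reshape([-1])]
    one_hot = one_hot.reshape((2, 2, num_classes + 1))
    print(one_hot[0, 0])  # [0. 1. 0.]  -> class 1
    print(one_hot[0, 1])  # [1. 0. 0.]  -> class 0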
    
    

    Training code:

    import numpy as np
    from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
    from keras.optimizers import Adam
    import os
    from unet_mini import unet_mini
    from mobilnet_unet import mobilenet_unet
    from callbacks import ExponentDecayScheduler, LossHistory
    from keras import backend as K
    from keras import backend
    from data_loader import UnetDataset

    #--------------------------------------
    # cross-entropy loss; cls_weights are the per-class weights
    #--------------------------------------
    def CE(cls_weights):
        cls_weights = np.reshape(cls_weights, [1, 1, 1, -1])
        def _CE(y_true, y_pred):
            y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())

            CE_loss = - y_true[..., :-1] * K.log(y_pred) * cls_weights
            CE_loss = K.mean(K.sum(CE_loss, axis=-1))
            return CE_loss
        return _CE

    def f_score(beta=1, smooth=1e-5, threshold=0.5):
        def _f_score(y_true, y_pred):
            y_pred = backend.greater(y_pred, threshold)
            y_pred = backend.cast(y_pred, backend.floatx())

            tp = backend.sum(y_true[..., :-1] * y_pred, axis=[0, 1, 2])
            fp = backend.sum(y_pred, axis=[0, 1, 2]) - tp
            fn = backend.sum(y_true[..., :-1], axis=[0, 1, 2]) - tp

            score = ((1 + beta ** 2) * tp + smooth) \
                    / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
            return score
        return _f_score

    def train():
        #-------------------------
        # the cell images are split into
        # cell wall vs. everything else
        # initialize the parameters
        #-------------------------
        num_classes = 2

        input_shape = (512, 512, 3)

        batch_size = 4

        learn_rate = 1e-4

        # epoch to resume training from
        start_epoch = 0
        end_epoch = 100
        num_workers = 4

        dataset_path = 'medical_datasets'

        model = mobilenet_unet(num_classes, input_shape=input_shape)

        model.summary()

        # read the image paths
        with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"), "r") as f:
            train_lines = f.readlines()

        logging         = TensorBoard(log_dir='logs/')
        checkpoint      = ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}.h5',
                            monitor='loss', save_weights_only=True, save_best_only=False, period=1)
        reduce_lr       = ExponentDecayScheduler(decay_rate=0.96, verbose=1)
        early_stopping  = EarlyStopping(monitor='loss', min_delta=0, patience=10, verbose=1)
        loss_history    = LossHistory('logs/', val_loss_flag=False)

        epoch_step      = len(train_lines) // batch_size
        cls_weights     = np.ones([num_classes], np.float32)
        loss = CE(cls_weights)
        model.compile(loss=loss,
                    optimizer=Adam(lr=learn_rate),
                    metrics=[f_score()])

        train_dataloader = UnetDataset(train_lines, input_shape[:2], batch_size, num_classes, True, dataset_path)

        print('Train on {} samples, with batch size {}.'.format(len(train_lines), batch_size))
        model.fit_generator(
                generator           = train_dataloader,
                steps_per_epoch     = epoch_step,
                epochs              = end_epoch,
                initial_epoch       = start_epoch,
                # use_multiprocessing = True if num_workers > 1 else False,
                workers             = num_workers,
                callbacks           = [logging, checkpoint, early_stopping, reduce_lr, loss_history]
            )

    if __name__ == "__main__":
        train()
    
    

    The final prediction results:

    If you are interested, download the complete code from the repo linked above and read it there; there is a lot of it, and pasting everything here would make the post far too long.

    And that is a simple semantic segmentation task~
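
    For completeness, here is a rough inference sketch showing how a trained model turns an image into a colored mask, following the argmax recipe from Section I. The checkpoint name, image file, and palette are placeholders, not files from the repo, and mobilenet_unet as written also expects the pretrained MobileNet weights under model_data/:

    import numpy as np
    from PIL import Image
    from mobilnet_unet import mobilenet_unet

    model = mobilenet_unet(num_classes=2, input_shape=(512, 512, 3))
    model.load_weights('logs/ep100-loss0.100.h5')  # placeholder checkpoint name

    # same preprocessing as training: resize, then scale to [-1, 1]
    img = Image.open('cell.png').convert('RGB').resize((512, 512), Image.BICUBIC)
    x = np.array(img, np.float64) / 127.5 - 1

    pred = model.predict(x[np.newaxis, ...])[0]     # (512, 512, num_classes)
    class_map = np.argmax(pred, axis=-1)            # class index per pixel

    # black for background, white for the target class
    palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
    Image.fromarray(palette[class_map]).show()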

    References

    https://github.com/bubbliiiing/unet-keras

    https://github.com/divamgupta/image-segmentation-keras 
