Pytorch 图像分割2-IDRiD糖网并发症分割

前言

大概有一年的时间了,接触这个糖尿病视网膜病变相关的医学影像识别,今天来用Pytorch一起实现数据集中四种并发症的语义分割。

数据集

首先最麻烦的地方一定是数据集,我们的数据集分为6个文件夹。[‘MA’,’EX’,’HE’,’SE’,’ApparentRetinopathy’,’NoApparentRetinopathy’]。其中前4个是4种病变特征[‘微血管瘤Microaneurysms’,’硬性渗出Hard Exudates’,’视网膜出血Haemorrhages’,’软性渗出Soft Exudates’]的mask,’ApparentRetinopathy’文件夹中存了原始眼底图片。最后一个文件夹’NoApparentRetinopathy’,包含了89张没病的眼底图片。
图片尺寸4288×2848。

这里我们通过Pytorch的方法,定义自己的数据集。我们将图片切割成512 x 512的小块进行训练,每张图切割成9 x 6 = 54块。不足512×512的补零。

代码

引入必要的库

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import random

定义Dataset

class IDRiD_sub1_dataset(Dataset):
    """IDRiD sub-challenge 1 segmentation dataset.

    Serves 512x512 crops cut from the 4288x2848 fundus photographs. Each
    image yields 9 columns x 6 rows = 54 tiles; the right/bottom edges are
    zero-padded up to 4608x3072 so every tile is a full 512x512 crop.

    Each sample is (image_crop, masks_crop, name):
        image_crop: FloatTensor [3, 512, 512]
        masks_crop: FloatTensor [4, 512, 512], one channel per lesion type
                    (MA, EX, HE, SE); channels are all-zero when the mask
                    file is missing or the image is healthy (NAR).
        name:       '<image name>( r, c)' identifying the tile position.
    """

    def __init__(self, root_dir):
        self.task_type_list = ['MA', 'EX', 'HE', 'SE']
        self.root_dir = root_dir
        # Each entry: (image_dir, mask_dirs, name); mask_dirs maps
        # task type -> mask path, or None (always None for NAR images).
        self.data_idx = []
        # Cache the most recently loaded full-size image and masks so the
        # 54 crops of one image do not re-read it from disk 54 times.
        self.data_cache = {'image': None, 'masks': None, 'name': '', 'index': None}

        image_root = os.path.join(self.root_dir, 'ApparentRetinopathy')
        image_NAR_root = os.path.join(self.root_dir, 'NoApparentRetinopathy')

        # Index images with apparent retinopathy (AR): look up the four
        # per-lesion mask files for each image.
        for filename in os.listdir(image_root):
            image_dir = os.path.join(image_root, filename)
            mask_dirs = {task_type: None for task_type in self.task_type_list}
            name = os.path.splitext(filename)[0]
            for task_type in self.task_type_list:
                m_dir = os.path.join(self.root_dir, task_type,
                                     name + '_' + task_type + '.tif')
                if os.path.isfile(m_dir):
                    mask_dirs[task_type] = m_dir
            self.data_idx.append((image_dir, mask_dirs, name))

        # Index healthy (NAR) images: no mask files exist, so all four
        # mask channels will be synthesized as zeros at load time.
        for filename in os.listdir(image_NAR_root):
            image_dir = os.path.join(image_NAR_root, filename)
            mask_dirs = {task_type: None for task_type in self.task_type_list}
            name = os.path.splitext(filename)[0]
            self.data_idx.append((image_dir, mask_dirs, name))

        # Shuffle once at construction: consecutive indices still map to the
        # same image (keeping the cache effective) but image order is random.
        random.shuffle(self.data_idx)

    def __len__(self):
        # 54 tiles (9 columns x 6 rows) per image.
        return len(self.data_idx) * 9 * 6

    def __getitem__(self, idx):
        # Decompose the flat index into (image, row, column).
        n, tile = divmod(idx, 6 * 9)  # image index, tile index within image
        r, c = divmod(tile, 9)        # row (0-5), column (0-8)

        # (Re)load the full-size image and masks only on a cache miss.
        if self.data_cache['index'] != n:
            image_dir, mask_dirs, name = self.data_idx[n]
            image = Image.open(image_dir)

            masks = []
            for task_type in self.task_type_list:
                if mask_dirs[task_type] is not None:
                    # AR image with a mask for this lesion type.
                    mask = np.array(Image.open(mask_dirs[task_type]), dtype='float32')
                else:
                    # Missing mask / NAR image: all-zero mask.
                    w, h = image.size
                    mask = np.zeros((h, w), dtype='float32')  # numpy is (h, w)
                masks.append(mask)
            masks = np.array(masks)
            # Zero-pad (2848, 4288) -> (3072, 4608) = (6*512, 9*512).
            masks = np.pad(masks, ((0, 0), (0, 224), (0, 320)),
                           'constant', constant_values=0)
            self.data_cache = {'image': image, 'masks': masks, 'name': name, 'index': n}

        # PIL uses (w, h) = (column, row). Image.crop takes
        # (left, upper, right, lower) and zero-fills out-of-bounds areas,
        # so edge tiles still come back as full 512x512 crops.
        image_crop = self.data_cache['image'].crop(
            (c * 512, r * 512, c * 512 + 512, r * 512 + 512))
        # masks: (4, 3072, 4608) -> (4, 512, 512)
        masks_crop = self.data_cache['masks'][:, r * 512:r * 512 + 512,
                                              c * 512:c * 512 + 512]
        image_crop = transforms.ToTensor()(image_crop)  # -> torch.Size([3, 512, 512])
        masks_crop = torch.from_numpy(masks_crop)       # -> torch.Size([4, 512, 512])
        name = self.data_cache['name'] + '(%2d ,%2d)' % (r, c)
        return image_crop, masks_crop, name



测试

  • 首先测试图像的数目

ApparentRetinopathy 54张
NoApparentRetinopathy 89张

(54 + 89) * 9 * 6 = 7722

dataset = IDRiD_sub1_dataset('/xxx/data/train/')
print('dataset length',len(dataset))

输出

dataset length 7722
  • 测试维度
print('dataset sample')

image, mask, name = dataset[random.randint(0, len(dataset) - 1)]
image.shape, mask.shape, name 

输出

dataset sample
(torch.Size([3, 512, 512]), torch.Size([4, 512, 512]), 'IDRiD_35( 4 , 3)')
  • 显示图片和Mask测试

这里选择视网膜出血这个mask较明显的来测试。找到10个有视网膜出血的,同时展示原图和mask。

%load_ext autoreload
%autoreload 2

%matplotlib inline

# import matplotlib.pyplot as plt
# idx = random.randint(0, len(dataset) - 1)
# print(idx)
import matplotlib.pyplot as plt  # was commented out above, but plt is used below

# Show the first 10 crops that contain haemorrhage (HE, channel 2) pixels,
# displaying the image crop and its mask side by side.
cnt = 0
for idx in range(len(dataset)):
    image, mask, name = dataset[idx]
    if np.count_nonzero(np.array(mask)[2]) != 0:  # crop has HE mask pixels
        plt.imshow(transforms.ToPILImage()(image))
        plt.show()
        plt.imshow(mask.numpy()[2])  # channel 2 = haemorrhages (most frequent mask)
        plt.show()
        cnt += 1
        if cnt == 10:
            break
  • 测试读取速度
from torch.utils.data import DataLoader
import time

# Time one full pass over the dataset with a batched loader.
start = time.time()
loader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=4)
for _batch in loader:
    pass
print('%ds' % (time.time() - start))

85s.

Model

Model 不是重点,主要因为还没到自己改模型的阶段,这里直接拿来用了,也不分析了。

'''
The code is modified from https://github.com/ZijunDeng/pytorch-semantic-segmentation
'''
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py
class _GlobalConvModule(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size):
        super(_GlobalConvModule, self).__init__()
        pad0 = int((kernel_size[0] - 1) / 2)
        pad1 = int((kernel_size[1] - 1) / 2)
        # kernel size had better be odd number so as to avoid alignment error
        super(_GlobalConvModule, self).__init__()
        self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1),
                                 padding=(pad0, 0))
        self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]),
                                 padding=(0, pad1))
        self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]),
                                 padding=(0, pad1))
        self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1),
                                 padding=(pad0, 0))

    def forward(self, x):
        x_l = self.conv_l1(x)
        x_l = self.conv_l2(x_l)
        x_r = self.conv_r1(x)
        x_r = self.conv_r2(x_r)
        x = x_l + x_r
        return x


class _BoundaryRefineModule(nn.Module):
    def __init__(self, dim):
        super(_BoundaryRefineModule, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        out = x + residual
        return out


class GCN(nn.Module):
    """Global Convolutional Network for semantic segmentation.

    ResNet-152 encoder with per-stage GCN + boundary-refinement heads and a
    bilinear-upsampling decoder that fuses the stages top-down; the output
    has one channel per class at `input_size` resolution.

    Args:
        num_classes: number of output mask channels.
        input_size: spatial size (h, w or single int) of the final output.

    NOTE(review): `pretrained=True` downloads ImageNet ResNet-152 weights on
    first use; `F.upsample` is deprecated in newer PyTorch in favor of
    `F.interpolate` — confirm target version before upgrading.
    """
    def __init__(self, num_classes, input_size):
        super(GCN, self).__init__()
        self.input_size = input_size
        resnet = models.resnet152(pretrained=True)

        # Reuse the ResNet stem and the four residual stages as the encoder.
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1)
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # One global-conv head per encoder stage, mapping to num_classes channels.
        self.gcm1 = _GlobalConvModule(2048, num_classes, (7, 7))
        self.gcm2 = _GlobalConvModule(1024, num_classes, (7, 7))
        self.gcm3 = _GlobalConvModule(512, num_classes, (7, 7))
        self.gcm4 = _GlobalConvModule(256, num_classes, (7, 7))

        # Boundary refinement after each GCN head (brm1-4) and after each
        # decoder fusion/upsampling step (brm5-9).
        self.brm1 = _BoundaryRefineModule(num_classes)
        self.brm2 = _BoundaryRefineModule(num_classes)
        self.brm3 = _BoundaryRefineModule(num_classes)
        self.brm4 = _BoundaryRefineModule(num_classes)
        self.brm5 = _BoundaryRefineModule(num_classes)
        self.brm6 = _BoundaryRefineModule(num_classes)
        self.brm7 = _BoundaryRefineModule(num_classes)
        self.brm8 = _BoundaryRefineModule(num_classes)
        self.brm9 = _BoundaryRefineModule(num_classes)

        # Only the new heads are re-initialized; the ResNet encoder keeps
        # its pretrained weights.
        initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3,
                           self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9)

    def forward(self, x):
        # Spatial sizes annotated for a 512x512 input.
        # if x: 512
        fm0 = self.layer0(x)  # 256
        fm1 = self.layer1(fm0)  # 128
        fm2 = self.layer2(fm1)  # 64
        fm3 = self.layer3(fm2)  # 32
        fm4 = self.layer4(fm3)  # 16

        # Per-stage class maps: GCN head followed by boundary refinement.
        gcfm1 = self.brm1(self.gcm1(fm4))  # 16
        gcfm2 = self.brm2(self.gcm2(fm3))  # 32
        gcfm3 = self.brm3(self.gcm3(fm2))  # 64
        gcfm4 = self.brm4(self.gcm4(fm1))  # 128

        # Top-down fusion: upsample the coarser map, add the next stage's
        # class map, refine; final steps upsample without a skip connection.
        fs1 = self.brm5(F.upsample(gcfm1, fm3.size()[2:], mode='bilinear') + gcfm2)  # 32
        fs2 = self.brm6(F.upsample(fs1, fm2.size()[2:], mode='bilinear') + gcfm3)  # 64
        fs3 = self.brm7(F.upsample(fs2, fm1.size()[2:], mode='bilinear') + gcfm4)  # 128
        fs4 = self.brm8(F.upsample(fs3, fm0.size()[2:], mode='bilinear'))  # 256
        out = self.brm9(F.upsample(fs4, self.input_size, mode='bilinear'))  # 512

        return out


def initialize_weights(*models):
    """Initialize the weights of the given modules in place.

    Conv2d/Linear weights get Kaiming-normal initialization with zeroed
    biases; BatchNorm2d layers are reset to weight=1, bias=0.

    Args:
        *models: nn.Module instances whose submodules are initialized.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # nn.init.kaiming_normal (no underscore) is deprecated and
                # removed in modern PyTorch; use the in-place variant.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

Util

常用的工具函数

# -*- coding: utf-8 -*
import torch
from sklearn.metrics import f1_score
import numpy as np
import os


def weighted_BCELoss(output, target, weights=None):
    """Binary cross-entropy with optional per-class weighting.

    Args:
        output: predicted probabilities (post-sigmoid), any shape.
        target: ground-truth labels in {0, 1}, same shape as output.
        weights: optional pair [w_pos, w_neg] scaling the positive and
            negative terms respectively.

    Returns:
        Scalar tensor: mean (weighted) BCE over all elements.
    """
    # Keep probabilities away from 0/1 so log() stays finite.
    output = output.clamp(min=1e-5, max=1 - 1e-5)
    pos_term = target * torch.log(output)
    neg_term = (1 - target) * torch.log(1 - output)
    if weights is None:
        loss = -pos_term - neg_term
    else:
        assert len(weights) == 2
        loss = -weights[0] * pos_term - weights[1] * neg_term
    return torch.mean(loss)


def evaluate(y_true, y_pred):
    """Compute the F1 score between ground truth and predictions.

    Args:
        y_true: pytorch tensor of ground-truth binary labels.
        y_pred: pytorch tensor of predicted probabilities.

    Returns:
        The F1 score, with predictions rounded to the nearest integer.
    """
    truth = y_true.numpy().flatten()
    pred = np.rint(y_pred.numpy().flatten())
    return f1_score(truth, pred)


def save_model(model, save_dir, name):
    """Save a model's state_dict to save_dir/name, creating the directory.

    Args:
        model: nn.Module whose state_dict is saved.
        save_dir: target directory, created if it does not exist.
        name: checkpoint file name.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, name)
    print('Saving model to directory "%s"' % (path))
    torch.save(model.state_dict(), path)

Train

关键的训练。

引入必要的库,设置参数。 这里看到原作者还是做了很多实验的 开始只训练了AR的图片,在256×256上,接着做了数据增强,加入了NAR图片,接着又提升了图片的尺寸大小。

这里为了简单起见,直接使用512×512,且加入NAR图片来进行训练,结果不一定多好。

from dataset import IDRiD_sub1_dataset
from util import evaluate, save_model, weighted_BCELoss
from model import GCN
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import time
import copy
import os

# gcn_v3 with weighted loss, AR only, 256x256
# gcn_v4 with random crop, 256x256
# gcn_v3_2 fine tune with AR and NAR
# gcn_v5 512x512

# NOTE: is_available must be CALLED — the bare function object is always
# truthy, which made every `if use_gpu:` branch run even on CPU-only machines.
use_gpu = torch.cuda.is_available()
save_dir = "./saved_models"
model_name = "test.pth"
data_train_dir = './data/train'
data_val_dir = './data/val'
batch_size = 8
num_epochs = 10
lr = 1e-4

创建dataloaders

def make_dataloaders(batch_size=batch_size):
    """Build train/val DataLoaders over the IDRiD datasets.

    Args:
        batch_size: batch size for both loaders (defaults to the
            module-level setting).

    Returns:
        dict with 'train' and 'val' DataLoader entries.
    """
    datasets = {
        'train': IDRiD_sub1_dataset(data_train_dir),
        'val': IDRiD_sub1_dataset(data_val_dir),
    }
    dataloaders = {
        phase: DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=4)
        for phase, ds in datasets.items()
    }
    print('Training data: %d\nValidation data: %d'
          % ((len(datasets['train'])), len(datasets['val'])))
    return dataloaders

训练模型

def train_model(model, num_epochs, dataloader, optimizer, schedular):
    """Train the model, tracking and restoring the best-validation-F1 weights.

    Args:
        model: the network to train (modified in place).
        num_epochs: number of epochs to run.
        dataloader: dict with 'train' and 'val' DataLoaders yielding
            (images, masks, name) batches.
        optimizer: optimizer stepping the model's parameters.
        schedular: ReduceLROnPlateau scheduler stepped on validation loss.

    Returns:
        The model with the best-validation-F1 weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_f1 = 0.0

    # Positive/negative class weights for the BCE loss; hoisted out of the
    # batch loop (previously rebuilt every iteration).
    weights = [5, 1]
    if use_gpu:
        weights = torch.FloatTensor(weights).cuda()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.eval()
            running_loss = 0.0
            running_f1 = 0.0
            data_num = 0

            for idx, data in enumerate(dataloader[phase]):
                images, masks, name = data

                if use_gpu:
                    images = images.cuda()
                    masks = masks.cuda()

                optimizer.zero_grad()  # clear accumulated gradients

                # Track gradients only in the training phase; replaces the
                # deprecated Variable(..., volatile=...) pattern.
                with torch.set_grad_enabled(phase == 'train'):
                    # forward: BCE needs probabilities, so apply sigmoid.
                    outputs = model(images)
                    outputs = torch.sigmoid(outputs)
                    loss = weighted_BCELoss(outputs, masks, weights)

                    # backward + update only when training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics (weighted by batch size for a proper mean)
                running_loss += loss.item() * images.size(0)
                data_num += images.size(0)
                outputs = outputs.detach().cpu()
                masks = masks.detach().cpu()
                running_f1 += evaluate(masks, outputs) * images.size(0)

                # progress output every 5 batches
                if idx % 5 == 0 and idx != 0:
                    print('\r{} {:.2f}%'.format(phase, 100 * idx / len(dataloader[phase])), end='\r')

            epoch_loss = running_loss / data_num
            epoch_f1 = running_f1 / data_num

            # The scheduler watches the validation loss.
            if phase == 'val':
                schedular.step(epoch_loss)
            print('{} Loss: {:.4f} F1 score: {:.4f}'.format(phase, epoch_loss, epoch_f1))

            # Keep (and checkpoint) the best-F1 weights seen so far.
            if phase == 'val' and epoch_f1 > best_f1:
                best_f1 = epoch_f1
                best_model_wts = copy.deepcopy(model.state_dict())
                save_model(model, save_dir, model_name)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best F1 score: {:.4f}'.format(best_f1))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

训练


dataloaders = make_dataloaders(batch_size=batch_size) # model model = GCN(4, 512) if use_gpu: model = model.cuda() # model = torch.nn.DataParallel(model).cuda() #model.load_state_dict(torch.load(os.path.join(save_dir, 'gcn_v5.pth'))) # training optimizer = optim.Adam(model.parameters(), lr=lr) scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True) model = train_model(model, num_epochs, dataloaders, optimizer, scheduler) # save save_model(model, save_dir, model_name)

测试


import numpy as np from sklearn.metrics import precision_recall_fscore_support, confusion_matrix from PIL import Image import matplotlib.pyplot as plt import os use_gpu = torch.cuda.is_available save_dir = "/home/chaiwenjun/tmp/pycharm_project_754/saved_models" model_name = "f1_score-0.12404182312802403_0.02034974513212168test.pth" data_dir = '/home/chaiwenjun/tmp/pycharm_project_754/data/val' batch_size = 8

测试效果。

def show_image_sample():
    """Stitch per-tile predictions back into full images and plot them.

    For each of the first 12 validation images: runs the model on all 54
    crops, reassembles the full-resolution image, ground-truth masks and
    predicted masks, then shows a 3x3 grid (image, 4 GT masks, 4 predictions).
    """
    dataset = IDRiD_sub1_dataset(data_dir)

    model = GCN(4, 512)
    if use_gpu:
        model = model.cuda()
    model.load_state_dict(torch.load(os.path.join(save_dir, model_name)))
    model.eval()

    def _show_panel(pos, sub_title, img):
        # One subplot panel with its axes hidden.
        plt.subplot(pos)
        plt.title(sub_title)
        fig = plt.imshow(img)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)

    for n in range(12):  # first 12 images
        full_image = np.zeros((3, 2848, 4288), dtype='float32')
        full_mask = np.zeros((4, 2848, 4288), dtype='float32')
        full_output = np.zeros((4, 2848, 4288), dtype='float32')

        title = ''
        for idx in range(9 * 6 * n, 9 * 6 * (n + 1)):
            image, mask, name = dataset[idx]
            # (Original code recomputed n here, clobbering the loop variable.)
            r = (idx % (6 * 9)) // 9  # row
            c = (idx % (6 * 9)) % 9   # column
            title = name[:-8]         # strip the '(%2d ,%2d)' suffix

            if use_gpu:
                image = image.cuda()
                mask = mask.cuda()
            image, mask = Variable(image, volatile=True), Variable(mask, volatile=True)

            # forward (add a batch dimension)
            output = model(image.unsqueeze(0))
            output = F.sigmoid(output)[0]

            # Stitch the crop back. Edge tiles are partly zero padding, so
            # clip to the real image extent. This also fixes the original
            # code, which skipped column 8 (the rightmost 192 px) entirely.
            h = min(512, 2848 - r * 512)
            w = min(512, 4288 - c * 512)
            full_output[:, r * 512:r * 512 + h, c * 512:c * 512 + w] = \
                output.cpu().data.numpy()[:, :h, :w]
            full_mask[:, r * 512:r * 512 + h, c * 512:c * 512 + w] = \
                mask.cpu().data.numpy()[:, :h, :w]
            full_image[:, r * 512:r * 512 + h, c * 512:c * 512 + w] = \
                image.cpu().data.numpy()[:, :h, :w]

        full_image = full_image.transpose(1, 2, 0)  # CHW -> HWC for imshow

        plt.figure()
        plt.axis('off')
        plt.suptitle(title)
        _show_panel(331, 'image', full_image)
        _show_panel(332, 'ground truth MA', full_mask[0])
        _show_panel(333, 'ground truth EX', full_mask[1])
        _show_panel(334, 'ground truth HE', full_mask[2])
        _show_panel(335, 'ground truth SE', full_mask[3])
        _show_panel(336, 'predict MA', full_output[0])
        _show_panel(337, 'predict EX', full_output[1])
        _show_panel(338, 'predict HE', full_output[2])
        _show_panel(339, 'predict SE', full_output[3])
        plt.show()
show_image_sample()

总结

问题还是比较复杂,暂时的效果并没有很好,相对上一个简单的肝脏分割的问题。还有很多需要学习的地方。

源码

https://github.com/chaiwenjun000/pytorch_study/blob/master/pytorch_idrid.ipynb

参考

代码学习自github https://github.com/TRKuan/IDRiD

点赞
  1. 最乐园说道:

    看到你写的这些代码就知道你不简单了!