PyTorch里构建Transforms进行医学病理图像数据增强

近期分析了一下PyTorch里torchvision里的transforms包,即,我们用PyTorch训练时常设置用于Normalizepreprocessing函数——

from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize((50, 50)),
transforms.ToTensor(),
transforms.Normalize(normMean, normStd)])

其实说是个transforms,我更把它理解为on-the-fly的数据增强

整个transforms包中,共有27个类:

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing"]

大致上可以分为四类——

  • 应用类;不是对图像本身进行变换,而是针对变换操作本身的处理;
  1. ["Compose", "RandomApply", "RandomChoice", "RandomOrder", "Lambda"]
  2. Lambda类存在意义很大,可以方便地将用户自定义的图像变换函数作为transform使用。如:将transforms.Lambda(lambda img: my_trans(img, 0.01))放进Compose中。
  • 基本变换类;进行基本的必备的变换;
  1. ["ToTensor", "ToPILImage", "Normalize", "Resize", "Pad"]
  2. 其中,pad操作中的"reflect", "symmetric"极好地弥补了PIL.Image中有些变换后边界只能置固定值的缺陷。
  • 选择类;图像本身没有变换,只是做了原图像内容的不同选择;
  1. ["CenterCrop", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "RandomRotation", "Grayscale", "RandomGrayscale", "RandomErasing"]
  2. 当然,这里有点冲突,形如RandomResizedCrop,同时包含图像内容选择和resize,我私认为主体还是图像内容选择。
  • 复杂变换类
  1. ["LinearTransformation", "ColorJitter", "RandomAffine", "RandomPerspective"]
  2. 包含了白化变换、HSV颜色空间的H和S通道扰动、亮度对比度扰动、仿射变换、透视变换。

除此以外,我也实现了几个特别适用于医学病理图像增强的几个增强类,包含有:

  • HEDJitter:HE染色病理图像的HED空间随机扰动增强
    采用颜色反卷积方法,将HE图像从RGB空间变换到HED空间,然后针对HED空间每一通道添加随机扰动,最后再变换到RGB空间。 其中,HED空间通道扰动策略:$s’ = \alpha * s + \beta$
  • RandomElastic:随机弹性变换
    对每个像素点分别进行xy的扰动偏差,具体可参考原论文section2
  • RandomAffineCV2:随机仿射变换(基于opencv方法)
    该仿射变换和torchvision中的RandomAffine本质一样,只不过实现上是PIL.Imagecv2的区别。采用opencv方法,可以对变换后的无内容边界区域进行镜像填充,而非PIL.Image那般死板只能填充固定像素值;
  • RandomGaussBlur:随机高斯模糊
    基于高斯滤波方法,对图像进行不同程度的模糊,适用于病理全扫描切片图像中的部分区域聚焦不够的情况;
  • AutoRandomRotation:自动随机旋转
    这个自动有点牵强,其实就是将torchvisionRandomRotation稍微改动了一下,从{0,90,180,270}这个集合中随机选择一个角度进行旋转,没有其他角度。这种做法主要是为了避免PIL.Image中边界区域只能填充固定像素值的问题(当然也可以用pad操作解决)。

代码已经开源托管在Github: Augmentation-PyTorch-Transforms ,并且自认为在README中清晰地介绍了增加的图像增强类的使用方式,或者可以查看Example_Transforms 参考,欢迎star,fork。

一个基本的myTransforms示例

preprocess = myTransforms.Compose([
myTransforms.RandomChoice([myTransforms.RandomHorizontalFlip(p=1),
myTransforms.RandomVerticalFlip(p=1),
myTransforms.AutoRandomRotation()]), # above is for: randomly selecting one for process
# myTransforms.RandomAffineCV2(alpha=0.1), # alpha \in [0,0.15],
# myTransforms.RandomAffine(degrees=0, translate=[0, 0.2], scale=[0.8, 1.2], shear=[-10, 10, -10, 10], fillcolor=(228, 218, 218)),
myTransforms.RandomElastic(alpha=2, sigma=0.06, mask=None),
myTransforms.ColorJitter(brightness=(0.65, 1.35), contrast=(0.5, 1.5)),
myTransforms.RandomChoice([myTransforms.ColorJitter(saturation=(0, 2), hue=0.3),
myTransforms.HEDJitter(theta=0.05)]),
# myTransforms.RandomGaussBlur(radius=[0.5, 1.5]),
myTransforms.ToTensor(), #operated on original image, rewrite on previous transform.
myTransforms.Normalize([0.6270, 0.5013, 0.7519], [0.1627, 0.1682, 0.0977])])
print(preprocess)