PAN（或者说 TEA、TSM、TSN 这一系列网络）所用的唯一一种数据增强方法：GroupMultiScaleCrop

以 something-something 数据集为例，因为这是我正在用的数据集。其它场景类数据集还会额外做水平翻转；而 something-something 的动作类别区分左右方向（例如"从左往右推某物"），翻转会改变标签的含义，所以不能使用水平翻转。

model.py的最后:

1
return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])])

具体调用的代码在transforms.py,加了些注释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class GroupMultiScaleCrop(object):
    """Multi-scale crop applied identically to every image of a group.

    A crop size is sampled as a fraction (``scales``) of the shorter image
    side, every image in the group is cropped at the same position, and each
    crop is then resized to ``input_size`` with bilinear interpolation.

    Args:
        input_size: Output size. An ``int`` means a square ``size x size``
            output; otherwise a ``(width, height)`` pair, passed straight to
            ``Image.resize``.
        scales: Candidate crop sizes as fractions of the shorter image side.
            Defaults to ``[1, .875, .75, .66]``.
        max_distort: Maximum index distance between the scale used for the
            crop height and the one used for the crop width; bounds how far
            the crop can deviate from square.
        fix_crop: If True, the crop position is drawn from a fixed grid of
            offsets (see ``fill_fix_offset``); otherwise it is drawn uniformly
            from all positions that keep the crop inside the image.
        more_fix_crop: If True, the fixed grid has 13 positions instead of 5.
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Crop all images of ``img_group`` at one shared window and resize them.

        Args:
            img_group: Sequence of PIL images. All images are assumed to have
                the same size as the first one.

        Returns:
            List of resized crops, in the same order as ``img_group``.
        """
        # Every image in the group is assumed to share one size, so the first
        # image determines the crop window for all of them.
        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        crop_img_group = [img.crop(box) for img in img_group]
        # Resize back to the network input size using bilinear interpolation.
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        """Randomly choose a crop window for an image of size ``(width, height)``.

        Returns:
            ``(crop_w, crop_h, w_offset, h_offset)`` in pixels; the window
            always lies inside the image.
        """
        image_w, image_h = im_size[0], im_size[1]

        # Candidate crop sizes are fractions of the shorter image side.
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]

        # Sizes within 3 px of the network input size are snapped to exactly
        # that size. Only snap when the target fits inside the image: snapping
        # up past the image border would make the crop leave the image (and,
        # with fix_crop=False, make random.randint see an empty range).
        target_w, target_h = self.input_size[0], self.input_size[1]
        crop_h = [target_h if abs(x - target_h) < 3 and target_h <= image_h else x
                  for x in crop_sizes]
        crop_w = [target_w if abs(x - target_w) < 3 and target_w <= image_w else x
                  for x in crop_sizes]

        # Combine height/width candidates whose scale indices differ by at
        # most max_distort (equal indices give a square crop).
        pairs = [(w, h)
                 for i, h in enumerate(crop_h)
                 for j, w in enumerate(crop_w)
                 if abs(i - j) <= self.max_distort]

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            # Default path: pick one of the fixed grid positions.
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset uniformly from the fixed grid produced by ``fill_fix_offset``."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """Return the fixed list of ``(w_offset, h_offset)`` crop positions.

        The free space ``image - crop`` along each axis is divided into 4
        integer steps. The base grid has 5 positions (4 corners + center);
        ``more_fix_crop`` adds 4 edge centers and 4 quarter points (13 total).
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret

入口是 `__call__` 方法；裁剪后缩放时使用 `Image.BILINEAR`，即双线性插值。