1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
class GroupMultiScaleCrop(object):
    """Multi-scale, aspect-jittered crop applied identically to a group of images.

    A crop size is sampled from fractions (``scales``) of the shorter image
    side, chosen independently for width and height but with their scale
    indices at most ``max_distort`` apart.  The crop is placed either at a
    uniformly random offset or at one of a fixed grid of offsets
    (``fix_crop``), cut out of every image in the group with the same box,
    and resized to ``input_size``.

    Args:
        input_size: Output size.  An ``int`` gives a square output; otherwise
            a 2-sequence whose index 0 is the width and index 1 the height.
        scales: Candidate crop sizes as fractions of the shorter image side.
            Defaults to ``[1, .875, .75, .66]``.
        max_distort: Maximum difference between the scale indices used for
            the crop width and the crop height (0 means square crops only).
        fix_crop: If true, pick the offset from a fixed grid (see
            ``fill_fix_offset``) instead of uniformly at random.
        more_fix_crop: With ``fix_crop``, use the 13-point grid instead of
            the 5-point (corners + center) grid.
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True,
                 more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = (input_size if not isinstance(input_size, int)
                           else [input_size, input_size])
        # NOTE(review): `Image` is presumably PIL.Image, imported at file level.
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Crop every image in ``img_group`` with one shared box, then resize.

        Args:
            img_group: Non-empty sequence of images (presumably PIL images;
                each needs ``.size``, ``.crop`` and ``.resize``).  The crop
                box is sampled from the FIRST image's size and reused for all.

        Returns:
            List of crops resized to ``input_size``, in input order.
        """
        im_size = img_group[0].size
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        crop_img_group = [img.crop(box) for img in img_group]
        out_size = (self.input_size[0], self.input_size[1])
        return [img.resize(out_size, self.interpolation) for img in crop_img_group]

    def _sample_crop_size(self, im_size):
        """Sample a crop size and its offset for an image of size ``im_size``.

        Args:
            im_size: ``(width, height)`` of the source image.

        Returns:
            Tuple ``(crop_w, crop_h, w_offset, h_offset)``.  The crop size never
            exceeds the image, so the offsets are non-negative and the box
            lies inside the image.

        Raises:
            IndexError: from ``random.choice`` if no (width, height) pair
                satisfies ``max_distort`` (e.g. empty ``scales`` or a negative
                ``max_distort``).
        """
        image_w, image_h = im_size[0], im_size[1]
        target_w, target_h = self.input_size[0], self.input_size[1]

        # Candidate sizes are fractions of the shorter side.
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]

        # Sizes less than 3 px from the output size are snapped to it
        # (presumably to skip a near-identity resize).  Snapping, or user
        # scales > 1, can push a size past the image edge, which used to make
        # randint() raise or produce negative fixed offsets; clamp to the image.
        crop_h = [min(target_h if abs(x - target_h) < 3 else x, image_h)
                  for x in crop_sizes]
        crop_w = [min(target_w if abs(x - target_w) < 3 else x, image_w)
                  for x in crop_sizes]

        # Allowed (w, h) pairs: scale indices at most max_distort apart.
        # Order (height-major) and duplicates are kept so the sampling
        # distribution is unchanged.
        pairs = [(w, h)
                 for i, h in enumerate(crop_h)
                 for j, w in enumerate(crop_w)
                 if abs(i - j) <= self.max_distort]

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(
                image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset uniformly from the fixed grid for this crop size."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h,
                                       crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """Return the fixed grid of candidate ``(w_offset, h_offset)`` pairs.

        The free space on each axis (image size minus crop size) is divided
        into quarters with floor division (``w_step``, ``h_step``), so the
        far-edge offsets ``4 * step`` may fall up to 3 px short of the edge.
        If the crop is larger than the image the steps are negative.

        Args:
            more_fix_crop: If true, append 8 extra points to the base 5.
            image_w, image_h: Image size.
            crop_w, crop_h: Crop size.

        Returns:
            List of 5 offsets (four corners, then center) or, with
            ``more_fix_crop``, 13 offsets (additionally the four edge
            midpoints and the four quarter points).
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = [
            (0, 0),                     # upper left
            (4 * w_step, 0),            # upper right
            (0, 4 * h_step),            # lower left
            (4 * w_step, 4 * h_step),   # lower right
            (2 * w_step, 2 * h_step),   # center
        ]

        if more_fix_crop:
            ret += [
                (0, 2 * h_step),            # center left
                (4 * w_step, 2 * h_step),   # center right
                (2 * w_step, 4 * h_step),   # lower center
                (2 * w_step, 0 * h_step),   # upper center
                (1 * w_step, 1 * h_step),   # upper-left quarter
                (3 * w_step, 1 * h_step),   # upper-right quarter
                (1 * w_step, 3 * h_step),   # lower-left quarter
                (3 * w_step, 3 * h_step),   # lower-right quarter
            ]

        return ret
|