Skip to content
Snippets Groups Projects
Select Git revision
  • ff8a13295b39565e5c71fe983b37e21e00686fb8
  • main default
2 results

preprocessing.py

Blame
  • user avatar
    Antoine Gaudron-desjardins authored
    ff8a1329
    History
    preprocessing.py 1.46 KiB
    import cv2
    import numpy as np
    
    
    def norm_by_imagenet(img):
        """Scale pixel values to [0, 1] and standardize the first three channels.

        Each image is divided by 255.0, then channels 0, 1 and 2 are shifted and
        scaled with the constants (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225).
        The input array is never modified; new arrays are returned.

        Args:
            img: Either a single image of rank 3 (indexed ``[row, col, channel]``),
                a rank-4 batch of such images, or a rank-1 array whose elements
                are images of possibly different shapes.

        Returns:
            * rank-3 input: the normalized image.
            * rank-4 input, or rank-1 input whose images all share one shape:
              the normalized images stacked with ``np.array``.
            * rank-1 input with differing image shapes: a 1-D ``dtype=object``
              array holding one normalized image per element.
            * any other rank: ``None`` (after printing a message).
        """
        rank = len(img.shape)
        if rank == 3:
            return _imagenet_normalize(img)
        if rank == 4 or rank == 1:
            # In SHA, shape of images varies, so the array.shape is (N, ), that's the '== 1' case.
            normalized = [_imagenet_normalize(im) for im in img]
            if len({im.shape for im in normalized}) > 1:
                # np.array() refuses to build a ragged array implicitly on recent
                # NumPy (and fails to broadcast on older ones when the leading
                # dimensions coincide), so fill an object array element by element.
                ragged = np.empty(len(normalized), dtype=object)
                for idx, im in enumerate(normalized):
                    ragged[idx] = im
                return ragged
            return np.array(normalized)
        print('Wrong shape of the input.')
        return None


    def _imagenet_normalize(im):
        """Return a normalized float copy of one ``[row, col, channel]`` image."""
        out = im / 255.0  # true division allocates a new array; ``im`` is untouched
        channel_stats = zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        for ch, (mean, std) in enumerate(channel_stats):
            out[:, :, ch] = (out[:, :, ch] - mean) / std
        return out
    
    def fix_singular_shape(img, unit_len=16):
        """Resize ``img`` so its first two dimensions are multiples of ``unit_len``.

        Some networks like w-net have both N max-pooling layers and concatenate
        layers, so unless the spatial shape is an integral multiple of 2 ** N the
        concatenated tensors end up with conflicting shapes.

        Each of the first two dimensions is rounded UP to the nearest multiple of
        ``unit_len``; a dimension that is already a multiple is left alone.

        Args:
            img: Rank-3 array (resized as an image) or rank-2 array (resized and
                then rescaled so that its sum equals the rounded sum of the
                input). Arrays of any other rank are returned unchanged.
            unit_len: Required divisor of the output height and width.

        Returns:
            The resized array. When both dimensions are already aligned, no
            resampling is needed and the input object itself is returned.
        """
        hei_src, wid_src = img.shape[0], img.shape[1]
        # Ceil-to-multiple: -(-a // b) is integer ceiling division.
        hei_dst = -(-hei_src // unit_len) * unit_len
        wid_dst = -(-wid_src // unit_len) * unit_len
        if hei_dst == hei_src and wid_dst == wid_src:
            # Already aligned: skip the lossy Lanczos resample entirely.
            return img

        rank = len(img.shape)
        if rank == 3:
            img = cv2.resize(img, (wid_dst, hei_dst), interpolation=cv2.INTER_LANCZOS4)
        elif rank == 2:
            # Remember the (rounded) total before resizing so it can be restored.
            GT = int(round(np.sum(img)))
            img = cv2.resize(img, (wid_dst, hei_dst), interpolation=cv2.INTER_LANCZOS4)
            resized_sum = np.sum(img)
            if resized_sum != 0:
                img = img * (GT / resized_sum)
            # A zero-sum map cannot be rescaled (0 / 0 would fill it with NaN);
            # it is left as produced by the resize.
        return img