Select Git revision
preprocessing.py
preprocessing.py 1.46 KiB
import cv2
import numpy as np
# ImageNet per-channel statistics, applied to channels 0, 1, 2 of the last axis.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _norm_channels_last(arr):
    """Return ``arr / 255`` with its first three last-axis channels standardized."""
    out = np.asarray(arr) / 255.0
    # Cast the constants to the working dtype so float32 input stays float32
    # and is computed exactly as with scalar arithmetic.
    mean = np.asarray(_IMAGENET_MEAN, dtype=out.dtype)
    std = np.asarray(_IMAGENET_STD, dtype=out.dtype)
    out[..., :3] = (out[..., :3] - mean) / std
    return out


def norm_by_imagenet(img):
    """Normalize RGB image(s) with the ImageNet channel mean/std.

    Pixel values are divided by 255, then channels 0, 1 and 2 of the last axis
    are standardized with mean (0.485, 0.456, 0.406) and std
    (0.229, 0.224, 0.225). Any further channels are only divided by 255.
    The input array is not modified.

    Args:
        img: one of
            * a single image, 3-D array (H, W, C) with C >= 3;
            * a batch of equally sized images, 4-D array (N, H, W, C);
            * a 1-D object array of N images, used when image shapes vary
              (as in SHA).

    Returns:
        A float array with the normalized image(s). For a 1-D input the result
        is a regular stacked array when all images share one shape, and a 1-D
        object array of per-image arrays otherwise. For any other rank, prints
        a message and returns None.
    """
    if img.ndim in (3, 4):
        # Channels are on the last axis, so one vectorized pass handles both
        # a single image and a batch.
        return _norm_channels_last(img)
    if img.ndim == 1:
        # In SHA, shape of images varies, so the array.shape is (N, ).
        imgs = [_norm_channels_last(im) for im in img]
        if all(im.shape == imgs[0].shape for im in imgs):
            return np.array(imgs)
        # Ragged batch: NumPy >= 1.24 refuses to build an array from
        # differently shaped sequences, so fill an object array explicitly.
        out = np.empty(len(imgs), dtype=object)
        for i, im in enumerate(imgs):
            out[i] = im
        return out
    print('Wrong shape of the input.')
    return None
def fix_singular_shape(img, unit_len=16):
    """Resize an image or 2-D map so its height and width are multiples of unit_len.

    Some networks, like W-Net, have both N max-pooling layers and concatenate
    layers; unless the input shape is an integral multiple of 2 ** N, the
    feature maps being concatenated end up with conflicting shapes.

    Args:
        img: either a 3-D array (H, W, C), resized with Lanczos interpolation,
            or a 2-D map (H, W). For a 2-D map, the result is rescaled after
            resizing so that its sum equals the original sum rounded to the
            nearest integer.
        unit_len: the multiple that height and width are rounded up to.

    Returns:
        The resized array. Dimensions that are already multiples of
        ``unit_len`` keep their size. Arrays with more than 3 dimensions are
        returned unchanged.
    """
    # Round each spatial dimension UP to the nearest multiple of unit_len;
    # an already-aligned dimension stays as it is.
    hei_dst = (img.shape[0] + unit_len - 1) // unit_len * unit_len
    wid_dst = (img.shape[1] + unit_len - 1) // unit_len * unit_len
    if img.ndim == 3:
        img = cv2.resize(img, (wid_dst, hei_dst), interpolation=cv2.INTER_LANCZOS4)
    elif img.ndim == 2:
        # Resizing changes the sum of the map, so remember the (rounded)
        # total and rescale to it afterwards.
        count = int(round(np.sum(img)))
        img = cv2.resize(img, (wid_dst, hei_dst), interpolation=cv2.INTER_LANCZOS4)
        resized_sum = np.sum(img)
        if count == 0:
            # Target total is zero: return an all-zero map instead of dividing
            # by zero (which yields an all-NaN map when resized_sum is 0 too).
            img = np.zeros_like(img)
        elif resized_sum != 0:
            img = img / (resized_sum / count)
    return img