Things Learned from implementing SRGAN in PyTorch

https://twitter.com/tim_dettmers/status/1059539322985054208?lang=en
PyTorch has .half(), which casts a model or tensor to 16-bit floats; it should work out of the box with that.

For faster input/output speed:
https://github.com/xinntao/BasicSR/wiki/Faster-IO-speed
1. Put the data on an SSD.
2. Crop the images beforehand so that you don't have to load full images during training. You will train many times (for debugging and so on), so the one-time cost pays off.
3. Convert to LMDB; it's faster.

In my case:
Like the author of the wiki page above, I generated 480*480 sub-images with a sliding window of step 240.
Initially I had 800 training images, totalling 3.5 GB.
After cropping I had 32,208 sub-images, totalling 12.3 GB, each of size 480*480*3.
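
The cropping step can be sketched roughly as below. This is a minimal illustration, not the script from the linked wiki: the function name crop_sub_images is made up for this example, and it assumes that any remainder at the right and bottom edges is simply dropped, which the original script may handle differently.

```python
import numpy as np

def crop_sub_images(img, size=480, step=240):
    """Return size x size crops of an H x W x C array, taken every `step` pixels.

    Assumption: leftover pixels at the right/bottom edges are dropped.
    """
    h, w = img.shape[:2]
    crops = []
    for top in range(0, h - size + 1, step):
        for left in range(0, w - size + 1, step):
            crops.append(img[top:top + size, left:left + size])
    return crops

# Dummy array standing in for an image read with cv2.imread.
dummy = np.zeros((960, 720, 3), dtype=np.uint8)
subs = crop_sub_images(dummy)
print(len(subs), subs[0].shape)
```

Each crop could then be written to disk once (for example with cv2.imwrite), so that training only ever loads the small files.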


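cv2 reads images in BGR channel order and returns a NumPy array of unsigned 8-bit integers (uint8), so be careful: unsigned arithmetic wraps around, so 62 - 91 gives 227 instead of -29.
import numpy as np
x = np.uint8(62)
y = np.uint8(91)
print(x - y)  # 227 (NumPy may also emit an overflow RuntimeWarning)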

To be safe, before doing any arithmetic on an image loaded with cv2, cast it to float first: img = img.astype(np.float64)
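
The same effect on whole arrays, as a small sketch. The two arrays here are made-up stand-ins for images loaded with cv2.imread, so the example runs without any image files.

```python
import numpy as np

# Stand-ins for two uint8 images (cv2.imread returns arrays of this dtype).
a = np.array([[62, 200]], dtype=np.uint8)
b = np.array([[91, 100]], dtype=np.uint8)

# uint8 arithmetic wraps modulo 256: 62 - 91 becomes 227.
wrapped = a - b

# Casting to float first keeps the true signed difference: -29.
safe = a.astype(np.float64) - b.astype(np.float64)

print(wrapped[0, 0], safe[0, 0])
```

Unlike the scalar case, array subtraction wraps silently, so a wrong result here will not announce itself.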


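PyTorch filter shape:
self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size))
self.bias = Parameter(torch.Tensor(out_channels))

// is floor division.
*kernel_size unpacks the kernel-size tuple, which for a square 2D convolution is (kernel_size, kernel_size). So the weight has shape (out_channels, in_channels // groups, kH, kW) and the bias has one entry per output channel.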

https://github.com/pytorch/pytorch/blob/be1ef5e4a40ef10adffb46f5b6028d19dc22aa7c/torch/nn/modules/conv.py#L32-L33
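
To make the shape rule concrete, here is a pure-Python sketch that mirrors the expression quoted above. The helper name conv2d_weight_shape is hypothetical (it is not part of PyTorch), and it assumes a 2D convolution where an integer kernel size means a square kernel.

```python
def conv2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    """Hypothetical helper: shape of a 2D conv weight per the expression above."""
    if isinstance(kernel_size, int):
        # An int kernel size stands for a square kernel.
        kernel_size = (kernel_size, kernel_size)
    if in_channels % groups != 0 or out_channels % groups != 0:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    # (out_channels, in_channels // groups, kH, kW)
    return (out_channels, in_channels // groups, *kernel_size)

print(conv2d_weight_shape(3, 64, 3))              # (64, 3, 3, 3)
print(conv2d_weight_shape(64, 128, 3, groups=2))  # (128, 32, 3, 3)
```

With PyTorch installed, the second case can be checked against torch.nn.Conv2d(64, 128, 3, groups=2).weight.shape, which should give the same four numbers.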