Implementing UNet in PyTorch

When learning image segmentation, UNet is one of the basic models to start with, and it is one of the most widely used models for the task. People make a lot of changes to the original UNet architecture, such as using a ResNet backbone, but let's implement the original UNet architecture in 7 steps.

Architecture of UNet

The architecture of UNet can be divided into two parts: the left part (contracting path) and the right part (expansion path).

The left part is a plain convolutional network. Two 3x3 convolution layers, each followed by a ReLU activation function, are stacked together (sequentially), and then a 2x2 max pool layer is applied (red arrow in the image). The first vertical bar on the left side of the image is not a layer; it represents the input (input image tile).

The right part is where the interesting things happen. Like the left side, the right part stacks two 3x3 convolution layers (sequentially), but instead of a max pool layer it uses a 2x2 transpose convolution layer (green arrow in the image). Along the expansion path, we take the feature map (copy) from the left side and combine it with the feature map on the right (grey arrow in the image). Because the right side also uses the sequential 3x3 convolution layers, their input is the combination of the copy from the left side and the output of the previous layer on the right (the half-white, half-blue box on the right side of the image).

At the end of the right side, an extra 1x1 convolution layer is applied as the output layer (output segmentation map).
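To see why the right side ends up smaller than the input, here is a small sketch that only traces the spatial size (height = width) of the feature maps. It assumes unpadded 3x3 convolutions (each trims 2 pixels), a 2x2 max pool that halves the size and a 2x2 transpose convolution that doubles it, as in the original UNet paper; the function names are hypothetical.

```python
def double_conv_size(size):
    # two unpadded 3x3 convolutions each trim 1 pixel from every border
    return size - 4

def trace_unet_sizes(input_size=572, depth=4):
    skips = []    # sizes of the left-side feature maps copied to the right
    margins = []  # pixels cropped from each border of those copies
    size = input_size
    # left side: double conv, then 2x2 max pool
    for _ in range(depth):
        size = double_conv_size(size)
        skips.append(size)
        size //= 2
    size = double_conv_size(size)  # bottom of the "U"
    # right side: 2x2 transpose conv, crop the matching copy, double conv
    for skip in reversed(skips):
        size *= 2
        margins.append((skip - size) // 2)
        size = double_conv_size(size)
    return skips, margins, size  # the final 1x1 conv keeps the size

print(trace_unet_sizes(572))  # ([568, 280, 136, 64], [4, 16, 40, 88], 388)
```

A 572x572 input therefore produces a 388x388 segmentation map, and every copied feature map must be cropped before it can be combined.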

So let's code the UNet architecture.

Full code: GitHub

As we have seen in the architecture, a double 3x3 convolution, each convolution followed by a ReLU activation function, is used on both the left and the right side.

Step 1.

A dual_conv() function is created with the in & out channel parameters.
Inside the function, a Sequential layer of two convolution layers with kernel size 3 (3x3 conv), each followed by a ReLU activation, is built.
dual_conv() returns this Sequential layer.
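A minimal sketch of dual_conv() in PyTorch. The parameter names in_channel and out_channel are my assumption, and the convolutions are left unpadded (PyTorch's default padding=0), as in the original UNet; that is why the feature maps shrink and need cropping later in Step 5.

```python
import torch
import torch.nn as nn

def dual_conv(in_channel, out_channel):
    # two 3x3 convolutions, each followed by ReLU; no padding, so each
    # convolution trims 1 pixel from every border of the feature map
    conv = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True),
    )
    return conv

block = dual_conv(1, 64)
out = block(torch.rand(1, 1, 572, 572))
print(out.size())  # torch.Size([1, 64, 568, 568])
```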

Step 2.

We will create a class Unet() and make the layers of the left side plus a max pool layer (the red arrow in the image). Each layer uses dual_conv(), since every block is a double convolution. Let's name these layers dwn_conv (there are 5 layers on the left side).

Step 3.

Make a forward() function inside the class, in which we forward-pass the input (image) through the left-side layers.
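Steps 2 and 3 together could look like the sketch below (it repeats dual_conv() so it runs on its own). The channel counts 64, 128, 256, 512, 1024 come from the original UNet; the single input channel and the names x1 to x8 are assumptions, while x7 and x9 match the names used later in this post. The outputs x1, x3, x5 and x7 are kept because the right side will copy them across.

```python
import torch
import torch.nn as nn

def dual_conv(in_channel, out_channel):
    # 3x3 conv -> ReLU -> 3x3 conv -> ReLU (unpadded, as in Step 1)
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True),
    )

class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        # left side: 5 double-conv layers (1 input channel assumed)
        self.dwn_conv1 = dual_conv(1, 64)
        self.dwn_conv2 = dual_conv(64, 128)
        self.dwn_conv3 = dual_conv(128, 256)
        self.dwn_conv4 = dual_conv(256, 512)
        self.dwn_conv5 = dual_conv(512, 1024)
        # 2x2 max pool with stride 2 (red arrow)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, image):
        x1 = self.dwn_conv1(image)  # copied to the right side later
        x2 = self.maxpool(x1)
        x3 = self.dwn_conv2(x2)     # copied
        x4 = self.maxpool(x3)
        x5 = self.dwn_conv3(x4)     # copied
        x6 = self.maxpool(x5)
        x7 = self.dwn_conv4(x6)     # copied
        x8 = self.maxpool(x7)
        x9 = self.dwn_conv5(x8)     # last layer of the left side
        return x9

model = Unet()
with torch.no_grad():
    x9 = model(torch.rand(1, 1, 572, 572))
print(x9.size())  # torch.Size([1, 1024, 28, 28])
```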

WALAAHH… the left side is complete!

After passing the image through the left side, we come to the interesting part: the right side of the architecture. So let's implement that.

Step 4.

Now let's declare the 4 layers of the right side, plus the final 1x1 convolution output layer, in the __init__() function of the class. On the right side, a 2x2 transpose convolution is used where the left side used a max pool.
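A minimal sketch of the two new layer types, assuming the original UNet channel counts and 2 output classes; the names trans1 and out are hypothetical. It shows that a 2x2 transpose convolution with stride 2 halves the channels and doubles the height and width, while the 1x1 convolution only changes the number of channels.

```python
import torch
import torch.nn as nn

# first of the 4 right-side upsampling layers (hypothetical name):
# 2x2 transpose convolution, stride 2, 1024 -> 512 channels
trans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# output layer (hypothetical name): 1x1 convolution, 64 channels -> 2 classes
out = nn.Conv2d(64, 2, kernel_size=1)

with torch.no_grad():
    up = trans1(torch.rand(1, 1024, 28, 28))  # stand-in for x9
    seg = out(torch.rand(1, 64, 388, 388))    # stand-in for the last right-side map

print(up.size())   # torch.Size([1, 512, 56, 56])
print(seg.size())  # torch.Size([1, 2, 388, 388])
```

The other three transpose convolutions follow the same pattern (512 to 256, 256 to 128, 128 to 64 channels), and each is followed by a dual_conv() whose input channels are doubled by the concatenation.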

Step 5.

As we have seen in the architecture, the input on the right side is a combination of the image from the left side (grey arrow in the image) and its previous layer. To combine the images they must be the same size, so let's make a function outside the class to crop the image.

What happens in crop_tensor():
tensor = the image from the left side, which needs to be cropped
target_tensor = the image on the right side, whose size the cropped left-side image has to match
Take the size of both tensors as target_size and tensor_size.
size()[2] reads dimension 2 of the N x C x H x W tensor, i.e. the height; since height and width are equal, it is also the width. E.g. for torch.Size([1, 512, 64, 64]), [2] = 64.
Now that we have both sizes, we subtract the smaller one (target_size) from the bigger one (tensor_size). Suppose target_size = 56 and tensor_size = 64: delta (the difference) is 8. We crop from both sides of the height and of the width, so we divide delta by 2 so that the image is cropped equally on each side.
delta = 8 => 4 pixels from the top and bottom, 4 pixels from the left and right

Now return the cropped tensor:
[:, :, ...] keeps all batch and channel entries
[delta:tensor_size-delta, delta:tensor_size-delta] is the cropped image
[4:64-4, 4:64-4] => [4:60, 4:60], which gives the 56x56 image we need in the example above
See the image below, which uses the height as an example.
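A sketch of crop_tensor() following the steps above. It assumes square feature maps in N x C x H x W layout and an even size difference, which holds for every copy in a 572x572 UNet.

```python
import torch

def crop_tensor(tensor, target_tensor):
    # tensor: larger left-side feature map (N, C, H, W)
    # target_tensor: smaller right-side feature map whose H and W we match
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size  # e.g. 64 - 56 = 8
    delta = delta // 2                 # crop equally from both borders: 4
    # keep all batch and channel entries, slice height and width
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

x7 = torch.rand(1, 512, 64, 64)
x = torch.rand(1, 512, 56, 56)
print(crop_tensor(x7, x).size())  # torch.Size([1, 512, 56, 56])
```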

Step 6.

Forward pass on the right side: we will write this in the forward() function inside the class.

First comes the transpose layer: x takes as input x9, the output of the last layer on the left side.
Now combine x with the corresponding layer on the left side (x7). But wait, the two images have different sizes, so we crop x7 using the crop_tensor() function; x is smaller than x7:
print(x7.size()): torch.Size([1, 512, 64, 64])
print(x.size()): torch.Size([1, 512, 56, 56])
Combine both images along the channel dimension using torch.cat() and pass the result to up_conv().
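One stage of the right side, sketched with random stand-in tensors of the sizes printed above. dual_conv() and crop_tensor() are repeated so the snippet runs on its own; the layer names trans1 and up_conv1 are hypothetical.

```python
import torch
import torch.nn as nn

def dual_conv(in_channel, out_channel):
    # 3x3 conv -> ReLU -> 3x3 conv -> ReLU, unpadded
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3), nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, kernel_size=3), nn.ReLU(inplace=True),
    )

def crop_tensor(tensor, target_tensor):
    # center-crop `tensor` to the height/width of `target_tensor`
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

trans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
up_conv1 = dual_conv(1024, 512)

with torch.no_grad():
    x9 = torch.rand(1, 1024, 28, 28)  # stand-in for the last left-side output
    x7 = torch.rand(1, 512, 64, 64)   # stand-in for the matching left-side map
    x = trans1(x9)                    # 28x28 -> 56x56
    y = crop_tensor(x7, x)            # 64x64 -> 56x56
    x = up_conv1(torch.cat([x, y], dim=1))  # 512 + 512 = 1024 channels in
print(x.size())  # torch.Size([1, 512, 52, 52])
```

The remaining three stages repeat this pattern with x5, x3 and x1 as the copied maps.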

Step 7.

Let's make an image (an image is just a tensor) using torch.rand(): 572 x 572 (height x width), since UNet takes a 572x572 image as input, and pass it to the model.
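Putting everything together, a full sketch of the model and the Step 7 check might look like this. The transpose-layer names trans1 to trans4, the name out, the single input channel and the 2 output classes are assumptions (the last two match the original UNet figure); the linked GitHub code may differ in details.

```python
import torch
import torch.nn as nn

def dual_conv(in_channel, out_channel):
    # 3x3 conv -> ReLU -> 3x3 conv -> ReLU, unpadded
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3), nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, kernel_size=3), nn.ReLU(inplace=True),
    )

def crop_tensor(tensor, target_tensor):
    # center-crop `tensor` (N, C, H, W) to the height/width of `target_tensor`
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

class Unet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        # left side (contracting path)
        self.dwn_conv1 = dual_conv(in_channels, 64)
        self.dwn_conv2 = dual_conv(64, 128)
        self.dwn_conv3 = dual_conv(128, 256)
        self.dwn_conv4 = dual_conv(256, 512)
        self.dwn_conv5 = dual_conv(512, 1024)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # right side (expansion path): transpose conv + double conv per stage
        self.trans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = dual_conv(1024, 512)
        self.trans2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = dual_conv(512, 256)
        self.trans3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = dual_conv(256, 128)
        self.trans4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = dual_conv(128, 64)
        # output layer: 1x1 conv, one channel per class
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, image):
        # left side
        x1 = self.dwn_conv1(image)
        x2 = self.maxpool(x1)
        x3 = self.dwn_conv2(x2)
        x4 = self.maxpool(x3)
        x5 = self.dwn_conv3(x4)
        x6 = self.maxpool(x5)
        x7 = self.dwn_conv4(x6)
        x8 = self.maxpool(x7)
        x9 = self.dwn_conv5(x8)
        # right side: upsample, crop the copied map, concatenate, double conv
        x = self.trans1(x9)
        x = self.up_conv1(torch.cat([x, crop_tensor(x7, x)], dim=1))
        x = self.trans2(x)
        x = self.up_conv2(torch.cat([x, crop_tensor(x5, x)], dim=1))
        x = self.trans3(x)
        x = self.up_conv3(torch.cat([x, crop_tensor(x3, x)], dim=1))
        x = self.trans4(x)
        x = self.up_conv4(torch.cat([x, crop_tensor(x1, x)], dim=1))
        return self.out(x)

image = torch.rand((1, 1, 572, 572))  # batch, channels, height, width
model = Unet()
with torch.no_grad():
    output = model(image)
print(output.size())  # torch.Size([1, 2, 388, 388])
```

The output map is 388x388 rather than 572x572 because the unpadded convolutions trim the borders at every step, which is also why the crop in Step 5 is needed.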


Reference: https://youtu.be/u1loyDCoGbE

Connect with me on LinkedIn: https://www.linkedin.com/in/rakshit01/
