PyTorch: Semantic Segmentation in ONE line of code

Semantic segmentation is an image analysis task in which every pixel of an image is assigned a class label.

  • Input: an RGB image with width (W), height (H) and 3 color channels
  • Output: a tensor with width (W), height (H) and depth equal to the number of possible classes, holding one score per class at each pixel. Applying a softmax over the class dimension turns these scores into a probability distribution over classes for each pixel.
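
To make the output format concrete, here is a minimal NumPy sketch of how per-class scores become a per-pixel probability distribution and then a label map. The function name `pixelwise_softmax`, the tiny 4 x 5 "image" and the random scores are illustrative assumptions, not part of the original code; the channel-first layout [classes x H x W] mirrors how PyTorch stores the model output.

```python
import numpy as np

def pixelwise_softmax(scores):
    """Turn [C x H x W] class scores into per-pixel probabilities (hypothetical helper)."""
    # subtract the per-pixel maximum for numerical stability
    shifted = scores - scores.max(axis=0, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(21, 4, 5))   # random stand-in for real model output
probs = pixelwise_softmax(scores)      # same shape, each pixel sums to 1 over classes
labels = probs.argmax(axis=0)          # [H x W] map of the most probable class
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw scores gives the same label map as taking the argmax of the probabilities.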

Our goal is to use simple code to perform this task.

# get the semantic segmentation of an input image
seg = get_segment('./image.png')
# visualize the segmentation as an RGB image (optional)
rgb = seg_to_rgb(seg)

1- Load the model

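Here we use the pretrained FCN (Fully Convolutional Network) model with a ResNet-101 backbone from torchvision. Calling eval() puts the model in inference mode. Note that recent torchvision releases deprecate pretrained=True in favor of the weights argument.

from torchvision import models
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()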

2- Define the transformations needed

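Prepare the image the way the pretrained model expects: T.Resize(256) scales the shorter side to 256 pixels while keeping the aspect ratio, T.ToTensor() converts the image to a float tensor with values in [0, 1], and T.Normalize applies the ImageNet per-channel mean and standard deviation.

import torchvision.transforms as T

trf = T.Compose([T.Resize(256),
                 T.ToTensor(),
                 T.Normalize(mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225])])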
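
As a rough illustration of what the normalization step computes, the sketch below applies the same per-channel arithmetic, (value - mean) / std, in plain NumPy. The `normalize` helper and the 2 x 2 test image are assumptions made for this example; only the mean and std values come from the transform above.

```python
import numpy as np

# per-channel statistics, reshaped so they broadcast over [3 x H x W]
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(chw):
    """Apply (x - mean) / std to a channel-first image in [0, 1] (hypothetical helper)."""
    return (chw - mean) / std

pixels = np.full((3, 2, 2), 0.5)   # a tiny mid-gray image, already scaled to [0, 1]
normalized = normalize(pixels)
```

A pixel whose value equals the channel mean maps to 0, and each channel ends up on a comparable scale, which is what the pretrained weights were trained on.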

3- Define the Segmentation function

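This function takes an image path as input and returns the model's class scores for each pixel of the image. The most probable class at a pixel is the one with the highest score.

import torch
from PIL import Image

def get_segment(image_path):
  img = Image.open(image_path).convert('RGB')  # drop any alpha channel; the model expects 3 channels
  inp = trf(img).unsqueeze(0)                  # add a batch dimension: [1 x 3 x H x W]
  with torch.no_grad():                        # inference only, no gradients needed
    seg = fcn(inp)['out']                      # class scores: [1 x 21 x H x W]
  return seg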

4- Visualize the output as an image

We visualize the output by selecting the most probable class per pixel (the highest-scoring of the 21 predefined classes in our case). Here we define a function that takes the model output [1 x 21 x H x W] and returns the corresponding RGB visualization [H x W x 3].

# Define the helper function
import numpy as np

def seg_to_rgb(image, nc=21):
    # collapse the class dimension into an [H x W] map of class indices
    image = torch.argmax(image.squeeze(), dim=0).detach().cpu().numpy()
    label_colors = np.array([(0, 0, 0),  # 0=background
                             # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                             (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                             # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                             (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                             # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                             (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                             # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                             (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    rgb = np.stack([r, g, b], axis=2)
    return rgb        
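
The per-class loop above can also be written as a single NumPy lookup, since indexing a color table with an integer label map returns one color per pixel. This is an alternative sketch, not the article's implementation: the name `labels_to_rgb` is hypothetical, and the palette below contains only the first four colors of `label_colors` to keep the example short.

```python
import numpy as np

def labels_to_rgb(labels, palette):
    """Map an [H x W] array of class indices to an [H x W x 3] uint8 image."""
    palette = np.asarray(palette, dtype=np.uint8)
    # fancy indexing: every label is replaced by its row in the palette
    return palette[labels]

# first four entries of label_colors: background, aeroplane, bicycle, bird
palette = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0)]
labels = np.array([[0, 3],
                   [1, 2]])
rgb = labels_to_rgb(labels, palette)
```

The result has the same [H x W x 3] layout as the output of seg_to_rgb, so it can be passed to plt.imshow in the same way.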

Examples:

import matplotlib.pyplot as plt

# get the image
!wget -nv https://static.independent.co.uk/s3fs-public/thumbnails/image/2018/04/10/19/pinyon-jay-bird.jpg -O bird.png
img = Image.open('./bird.png')
plt.imshow(img); plt.show()

# segment and visualize
seg = get_segment('./bird.png')
rgb = seg_to_rgb(seg)
plt.imshow(rgb); plt.show()
!wget -nv https://user-images.githubusercontent.com/3080674/29361099-52eb370c-8286-11e7-8274-ceb4895fe0b9.png -O img01.png

# segment and visualize
seg = get_segment('./img01.png')
rgb = seg_to_rgb(seg)
plt.imshow(rgb); plt.show()        
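
Besides looking at the colored image, it can help to check which classes were actually predicted. The sketch below summarizes a label map as the fraction of pixels per class. The `classes_present` helper and the small hand-written label map are assumptions for illustration; the class names are taken from the comments in seg_to_rgb. For a real result, the label map would be the per-pixel argmax of the model output.

```python
import numpy as np

# class names in index order, as listed in the seg_to_rgb comments
CLASS_NAMES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
               'horse', 'motorbike', 'person', 'potted plant', 'sheep',
               'sofa', 'train', 'tv/monitor']

def classes_present(labels):
    """Return {class name: fraction of pixels} for an [H x W] label map."""
    ids, counts = np.unique(labels, return_counts=True)
    return {CLASS_NAMES[i]: float(c) / labels.size for i, c in zip(ids, counts)}

# hand-written stand-in for a real label map: half background, half bird
labels = np.array([[0, 0, 3],
                   [0, 3, 3]])
summary = classes_present(labels)
```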

Colab code can be found here

Regards
