Segmenting medical images with the SAM model in Python
In this tutorial, we will take advantage of the amazing Segment Anything Model (SAM) from Meta AI to segment a lesion from a CT scan. The segmented lesion can then be saved as a separate image.
We first install the library with pip:
!pip install git+https://github.com/facebookresearch/segment-anything.git
Then we import the required libraries:
import torchvision
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from skimage import data
from skimage.io import imread,imsave
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import gdown
The weights are downloaded into a local file named "SAM_weights" with the help of the gdown library. Note that the original snippet never defined url; the value below is the official ViT-H checkpoint from the segment-anything repository:
url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'  # official ViT-H checkpoint

def download_weights(url):
    gdown.download(url, 'SAM_weights', quiet=False)

download_weights(url)
Afterward, we load the image to be segmented. In our case, we have a CT scan with a glioblastoma. Let's display the scan:
image = imread('PATH/TO/THE/IMAGE')
plt.figure(figsize=(5,5))
plt.imshow(image)
plt.axis('off')
plt.show()
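One practical note before running the model: SAM's automatic mask generator expects an (H, W, 3) uint8 RGB array, while a CT slice often loads as a single-channel grayscale image. In that case the channel can simply be replicated three times. A minimal sketch with a synthetic slice (the random array is a stand-in for the real scan, not part of the tutorial's data):

```python
import numpy as np

# Synthetic single-channel "CT slice" standing in for the real scan
slice_2d = np.random.randint(0, 256, size=(128, 128), dtype=np.uint8)

# SAM's mask generator expects an (H, W, 3) uint8 RGB image,
# so replicate the grayscale channel three times
if slice_2d.ndim == 2:
    image_rgb = np.stack([slice_2d] * 3, axis=-1)

print(image_rgb.shape, image_rgb.dtype)  # (128, 128, 3) uint8
```

If your imread call already returns a three-channel array, this step is unnecessary.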
Next, we configure the model (falling back to the CPU when no GPU is available):
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_checkpoint = "SAM_weights"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
And now we are ready to use SAM's SamAutomaticMaskGenerator to generate the masks.
mask_generator_ = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.9,
    stability_score_thresh=0.96,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)
masks = mask_generator_.generate(image)
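Each element of masks is a dictionary; its keys (segmentation, area, bbox, predicted_iou, stability_score, and so on) follow the segment-anything README. The toy sketch below uses two hand-built annotations, not real model output, to show the area-based ordering we rely on in the next step:

```python
import numpy as np

# Two hypothetical annotations mimicking entries of `masks`
anns = [
    {"segmentation": np.zeros((4, 4), dtype=bool), "area": 30},
    {"segmentation": np.ones((4, 4), dtype=bool), "area": 120},
]

# Largest area first, so label 0 always marks the biggest region
sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)
print([a["area"] for a in sorted_anns])  # [120, 30]
```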
Now comes the real magic of SAM. With the visualize_segmentations function we display each segmented area with its corresponding number. The numbers are sorted from the biggest area (number zero) to the smallest.
def visualize_segmentations(anns):
    if len(anns) == 0:
        return
    # Sort the annotations so that index 0 is the largest area
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    text_color = (0, 0, 0)  # black
    text_scale = 1
    text_thickness = 1
    for j, ann in enumerate(sorted_anns):
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        # Compute the center of mass of the binary mask to place the label
        text_size, _ = cv2.getTextSize(str(j), cv2.FONT_HERSHEY_SIMPLEX, text_scale, text_thickness)
        M = cv2.moments(m.astype('uint8'))
        center_x = int(M['m10'] / M['m00'])
        center_y = int(M['m01'] / M['m00'])
        text_x = center_x - text_size[0] // 2
        text_y = center_y + text_size[1] // 2
        # Fill the overlay with a random color, then draw the index on top
        for i in range(3):
            img[:, :, i] = color_mask[i]
        cv2.putText(img, str(j), (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, text_scale, text_color, text_thickness)
        # Use the mask itself as a 35%-opacity alpha channel
        img = np.dstack((img, m * 0.35))
        plt.imshow(img)
plt.figure(figsize=(10,10))
plt.imshow(image)
visualize_segmentations(masks)
plt.axis('off')
plt.show()
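The label placement above relies on cv2.moments to find each mask's center of mass, which in NumPy terms is simply the mean coordinate of the foreground pixels. A standalone check on a toy mask (the block of ones is an arbitrary example):

```python
import numpy as np

# Toy binary mask: a 4x4 block of ones inside a 10x10 image
m = np.zeros((10, 10), dtype=np.uint8)
m[2:6, 3:7] = 1

# Center of mass = mean of the foreground pixel coordinates,
# matching M['m10']/M['m00'] and M['m01']/M['m00'] from cv2.moments
ys, xs = np.nonzero(m)
center_x = int(xs.mean())
center_y = int(ys.mean())
print(center_x, center_y)  # 4 3
```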
As we can see from the result, the lesion corresponds to area #2.
And now we are ready to isolate whatever area we want. Let's isolate and save the lesion.
def find_the_segmentation(index_):
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
    print('The segmented area has {} pixels'.format(sorted_anns[index_]['area']))
    # Add a trailing axis so the mask broadcasts over the RGB channels
    mask = np.expand_dims(sorted_anns[index_]['segmentation'], axis=-1)
    # Keep the original pixel values inside the mask, zero everything else
    segmented_area = (image * mask).astype('uint8')
    imsave('lesion.png', segmented_area)
    plt.imshow(segmented_area)
    plt.axis('off')
    plt.show()

find_the_segmentation(2)
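The isolation step works through NumPy broadcasting: np.expand_dims turns the (H, W) boolean mask into (H, W, 1), so the multiplication zeroes every pixel outside the mask across all three color channels at once. A self-contained sketch with a toy image and mask (both invented for illustration):

```python
import numpy as np

# Toy RGB "scan" and a small square lesion mask
toy_image = np.full((4, 4, 3), 200, dtype=np.uint8)
seg = np.zeros((4, 4), dtype=bool)
seg[1:3, 1:3] = True

# (4, 4) -> (4, 4, 1) so the mask broadcasts over the RGB channels
mask = np.expand_dims(seg, axis=-1)
lesion = (toy_image * mask).astype("uint8")

print(lesion[1, 1].tolist())  # [200, 200, 200] -- kept inside the mask
print(lesion[0, 0].tolist())  # [0, 0, 0] -- zeroed outside
```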
I hope you now have an idea of how to do it. The SAM model has many aspects still worth investigating, and its capabilities are vast!
For anyone interested, this is the Google Colab link: https://colab.research.google.com/drive/1xx8ABlrkKjUfDV8DO33dVBhNIzBYN4lr?usp=sharing