Meta AI Introduces Revolutionary Image Segmentation Model Trained on 1 Billion Masks

Author: Murphy  |  Time: 2025-03-23 19:01:38

Introduction

After the revolutionary step OpenAI's ChatGPT made in NLP, AI progress continues, and Meta AI has now introduced an astonishing advance in computer vision. The Meta AI research team introduced the Segment Anything Model (SAM) together with a dataset of 1 billion masks on 11 million images. Image segmentation is the task of identifying which pixels of an image belong to an object.

Demo of image segmentation by ai.facebook.com
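To make the definition of segmentation concrete: a segmentation mask is commonly stored as a boolean array with the same height and width as the image, where True marks the pixels that belong to the object. The sketch below uses a made-up 6×8 mask (hypothetical data) and derives two quantities that the automatic mask generator also reports at the end of this article: the mask's area in pixels and its bounding box. The (x, y, width, height) box convention is an assumption of this sketch, chosen because, to my knowledge, it is also what SAM's automatic mask generator uses.

```python
import numpy as np

def mask_area_and_bbox(mask: np.ndarray):
    """Return the pixel area and the (x, y, width, height) box of a boolean H x W mask."""
    area = int(mask.sum())
    if area == 0:
        return 0, (0, 0, 0, 0)
    rows = np.where(mask.any(axis=1))[0]  # row indices (y) that contain object pixels
    cols = np.where(mask.any(axis=0))[0]  # column indices (x) that contain object pixels
    x0, y0 = int(cols[0]), int(rows[0])
    width = int(cols[-1] - cols[0] + 1)
    height = int(rows[-1] - rows[0] + 1)
    return area, (x0, y0, width, height)

# Hypothetical 6 x 8 "image" in which a 3 x 4 rectangle is the object
mask = np.zeros((6, 8), dtype=bool)
mask[2:5, 3:7] = True
area, bbox = mask_area_and_bbox(mask)
print(area, bbox)  # 12 (3, 2, 4, 3)
```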

The proposed project mainly includes three pillars: Task, Model and Data.

1. Segment Anything Task

The main goal of the Meta AI team was to create a promptable image segmentation model that works with a user's input prompt, much as ChatGPT does. Their solution was to combine the user's prompt with the image to produce segmentation masks. A segmentation prompt can be any information indicating what to segment in an image: for example, a set of foreground or background points, a box, or free-form text. The model's output is therefore a valid segmentation mask for any user-defined prompt.
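In the official segment-anything code used in the demo below, point and box prompts are passed to SamPredictor.predict as NumPy arrays: point_coords of shape (N, 2) holding (x, y) pixel coordinates, point_labels of shape (N,) with 1 for a foreground point and 0 for a background point, and box as [x0, y0, x1, y1]. The helper below, make_prompt, is hypothetical (it is not part of the library); it is a minimal sketch that builds and sanity-checks such arrays so the prompt format is visible without loading the model.

```python
import numpy as np

def make_prompt(fg_points=(), bg_points=(), box=None):
    """Hypothetical helper (not part of segment-anything): build SAM-style prompt arrays.

    fg_points / bg_points: iterables of (x, y) pixel coordinates.
    box: optional (x0, y0, x1, y1) in pixels.
    """
    fg = [tuple(p) for p in fg_points]
    bg = [tuple(p) for p in bg_points]
    # point_coords: N x 2 array of (x, y); reshape keeps shape (0, 2) when there are no points
    coords = np.array(fg + bg, dtype=np.float32).reshape(-1, 2)
    # point_labels: 1 = foreground (include), 0 = background (exclude)
    labels = np.array([1] * len(fg) + [0] * len(bg), dtype=np.int32)
    box_arr = None
    if box is not None:
        box_arr = np.array(box, dtype=np.float32)
        if box_arr.shape != (4,) or box_arr[0] > box_arr[2] or box_arr[1] > box_arr[3]:
            raise ValueError("box must be (x0, y0, x1, y1) with x0 <= x1 and y0 <= y1")
    return coords, labels, box_arr

# One foreground click on the object and one background click elsewhere
coords, labels, box = make_prompt(fg_points=[(465, 300)], bg_points=[(100, 50)])
print(coords.shape, labels)  # (2, 2) [1 0]
```

The three arrays correspond to the point_coords, point_labels, and box arguments of SamPredictor.predict used in Part 1 of the demo. Text prompts are explored in the paper, but as far as I know the publicly released code does not accept them.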

2. Segment Anything Model

The promptable Segment Anything Model (SAM) has three components, shown in the figure below.

Segment anything model workflow by ai.facebook.com

At a high level, the model architecture consists of an image encoder, a prompt encoder, and a mask decoder. For the image encoder, the team used a Vision Transformer (ViT) [2] pre-trained with MAE (masked autoencoders) [1]. ViT-based models are among the strongest architectures for image classification and segmentation tasks. The prompts are divided into two types: sparse prompts, such as points, boxes, and text, and dense prompts, such as masks. The prompt encoder creates embeddings for each type of prompt. Finally, the mask decoder maps the image embedding, the prompt embeddings, and an output token to a mask.
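To make the data flow between the three components concrete, here is a toy NumPy sketch. It is not SAM's implementation: the weights are random, the "image encoder" is a single linear projection of 16×16 patches instead of a ViT, and the "mask decoder" is a dot product between one prompt-conditioned output token and each patch embedding instead of a transformer decoder. The one idea borrowed from SAM is encoding point prompts as sinusoidal positional features (random Fourier features) summed with a learned per-label embedding. All names and sizes (EMBED_DIM, PATCH, the weight matrices) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
EMBED_DIM = 16  # illustrative embedding size (SAM's is much larger)
PATCH = 16      # patch size in pixels

# Random "weights" stand in for trained parameters
W_patch = rng.normal(size=(PATCH * PATCH * 3, EMBED_DIM)) / np.sqrt(PATCH * PATCH * 3)
B_fourier = rng.normal(size=(2, EMBED_DIM // 2))
label_embed = rng.normal(size=(2, EMBED_DIM))  # row 0: background, row 1: foreground
output_token = rng.normal(size=EMBED_DIM)

def image_encoder(image):
    """Split an H x W x 3 image into patches and project each patch to EMBED_DIM."""
    gh, gw = image.shape[0] // PATCH, image.shape[1] // PATCH
    patches = (image[: gh * PATCH, : gw * PATCH]
               .reshape(gh, PATCH, gw, PATCH, 3)
               .transpose(0, 2, 1, 3, 4)
               .reshape(gh, gw, -1))
    return patches @ W_patch  # gh x gw x EMBED_DIM

def prompt_encoder(points, labels, image_hw):
    """Sinusoidal encoding of (x, y) points plus a per-label embedding."""
    h, w = image_hw
    norm = points / np.array([w, h], dtype=float)  # normalize to [0, 1]
    proj = 2 * np.pi * (2 * norm - 1) @ B_fourier
    pe = np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)
    return pe + label_embed[labels]  # N x EMBED_DIM

def mask_decoder(image_emb, prompt_emb):
    """Score every patch against a prompt-conditioned output token."""
    token = output_token + prompt_emb.mean(axis=0)
    logits = image_emb @ token  # gh x gw
    # Upsample patch logits back to pixel resolution
    logits = np.repeat(np.repeat(logits, PATCH, axis=0), PATCH, axis=1)
    return logits > 0  # boolean mask

image = rng.random((64, 96, 3))
points = np.array([[40.0, 20.0]])
labels = np.array([1])
mask = mask_decoder(image_encoder(image), prompt_encoder(points, labels, image.shape[:2]))
print(mask.shape)  # (64, 96)
```

The split matters for speed: in SAM the heavy image encoder runs once per image, while the lightweight prompt encoder and mask decoder can be run for many prompts. That is why the demo calls predictor.set_image once before calling predictor.predict.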

3. Segment Anything Data

3.1 Segment Anything Data Engine

Garbage in garbage out (image by the author)

The principle of "garbage in, garbage out" applies to AI as well: if the input data is of poor quality, the model's output will be poor too. That is why the Meta team put a lot of effort into the quality of its training data. The team built a data engine, in which the model itself assists the annotators, to collect high-quality masks for the images. The data engine works in three stages:

  1. Manual stage: Professional human annotators labeled masks on the images by hand, assisted by an interactive annotation tool powered by SAM.
  2. Semi-automatic stage: The model, trained on the annotated images, was run on the remaining images. Human annotators were then asked to label additional objects that the model had missed and to correct segments with low confidence scores.
  3. Fully automatic stage: Masks are generated automatically by prompting the model with a regular grid of points (32×32 in the paper) and then filtered automatically: only confident, stable, non-ambiguous masks are kept, and masks are also filtered by size.
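The fully automatic stage can be sketched as follows. In the paper, a mask is considered stable if binarizing the model's output at a slightly higher and a slightly lower threshold gives nearly the same mask. This NumPy sketch implements that idea as the IoU between the masks obtained at threshold + offset and threshold - offset, builds a regular grid of point prompts, and keeps only masks that are confident, stable, and large enough. The threshold values mirror what I believe are SamAutomaticMaskGenerator's defaults (pred_iou_thresh=0.88, stability_score_thresh=0.95); the min_area cutoff and all function names are assumptions, and the non-maximum suppression step that removes duplicate masks is omitted.

```python
import numpy as np

def point_grid(n_per_side):
    """Regular n x n grid of (x, y) prompts in normalized [0, 1] coordinates."""
    offset = 1.0 / (2 * n_per_side)  # center each point in its grid cell
    ticks = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(ticks, ticks)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

def stability_score(logits, threshold=0.0, offset=1.0):
    """IoU between the mask binarized at threshold + offset and at threshold - offset."""
    high = logits > threshold + offset  # always a subset of `low`, so it is the intersection
    low = logits > threshold - offset   # ...and `low` is the union
    union = int(low.sum())
    return int(high.sum()) / union if union > 0 else 0.0

def filter_masks(logits_list, pred_ious, iou_thresh=0.88, stab_thresh=0.95, min_area=100):
    """Keep confident, stable, sufficiently large masks (no NMS in this sketch)."""
    kept = []
    for logits, pred_iou in zip(logits_list, pred_ious):
        mask = logits > 0
        if pred_iou >= iou_thresh and stability_score(logits) >= stab_thresh and mask.sum() >= min_area:
            kept.append(mask)
    return kept

sharp = np.full((20, 20), -10.0)
sharp[5:15, 5:15] = 10.0                          # confident, crisp boundary
blurry = np.linspace(-5, 5, 400).reshape(20, 20)  # soft, ambiguous boundary
print(point_grid(32).shape)                       # (1024, 2)
print(stability_score(sharp))                     # 1.0
kept = filter_masks([sharp, blurry], [0.95, 0.95])
print(len(kept))                                  # 1
```

The blurry mask is rejected even though its predicted IoU is high, because shifting the threshold changes it a lot.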

3.2 Segment Anything Dataset

The Segment Anything Data Engine produced a dataset of 1 billion masks (SA-1B) on 11 million diverse, high-resolution (3300×4900 pixels on average), licensed images. It is worth noting that 99.1% of the masks were generated fully automatically; their quality is nevertheless high because they were carefully filtered.
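A quick back-of-the-envelope calculation with the rounded figures above shows the scale of the dataset: roughly 91 masks and about 16 megapixels per image on average. (The paper reports 1.1 billion masks, roughly 100 per image; the rounded 1 billion from the text is used here.)

```python
total_masks = 1_000_000_000   # rounded figure from the text
total_images = 11_000_000
avg_h, avg_w = 3300, 4900     # average resolution in pixels, as stated above

masks_per_image = total_masks / total_images
megapixels_per_image = avg_h * avg_w / 1e6

print(f"{masks_per_image:.1f} masks per image")            # 90.9 masks per image
print(f"{megapixels_per_image:.2f} megapixels per image")  # 16.17 megapixels per image
```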

Conclusion – Why is the model revolutionary

The Meta AI team, together with teams at other large companies, is making great progress in AI. The Segment Anything Model (SAM) could power applications in numerous domains that require finding and segmenting any object in any image. For example:

  • SAM could be a component of a larger multimodal model that integrates images, text, audio, etc.
  • In AR/VR, SAM could enable selecting an object based on a user's gaze and then "lifting" it into 3D.
  • SAM could improve creative applications, such as extracting image regions for video editing.
  • and many more.

Image Segmentation Demo

In this part, I will use the official GitHub code in Google Colab to experiment with the model and perform two types of segmentation on an image: first, segmentation with a user-defined prompt, and second, fully automatic segmentation.

Part 1: Image segmentation using a user-defined prompt

  1. Set up (import libraries and installations)
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

using_colab = True

if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    # Lines starting with "!" are IPython/Colab shell commands, not plain Python
    # Install OpenCV, Matplotlib and the segment-anything package from GitHub
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    # Download the two example images from the segment-anything repository
    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg

    # Download the default (ViT-H) SAM checkpoint
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
  2. Helper functions to plot masks and points on the image.
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
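show_mask relies on NumPy broadcasting: the H × W mask is reshaped to H × W × 1 and multiplied by a 1 × 1 × 4 RGBA color, so pixels inside the mask get the color with alpha 0.6, while pixels outside become (0, 0, 0, 0), i.e. fully transparent. That is why only the object is tinted when the overlay is drawn on top of the image. The tiny example below reproduces that step on a hypothetical 2 × 2 mask.

```python
import numpy as np

mask = np.array([[True, False],
                 [False, True]])                   # hypothetical 2 x 2 mask
color = np.array([30/255, 144/255, 255/255, 0.6])  # same RGBA color as show_mask

h, w = mask.shape[-2:]
# (2, 2, 1) * (1, 1, 4) broadcasts to (2, 2, 4); False pixels become all zeros
rgba = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

print(rgba.shape)  # (2, 2, 4)
print(rgba[0, 1])  # [0. 0. 0. 0.]
```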
  3. Input image (the image to segment). Let's try to select the mask of the first grocery bag.
image = cv2.imread('/content/images/groceries.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 
plt.figure(figsize=(5,5))
plt.imshow(image)
plt.axis('on')
plt.show()
Input image (image from Facebook research)
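The cv2.cvtColor call matters because OpenCV's imread returns the color channels in BGR order, while SamPredictor.set_image expects RGB by default. For a 3-channel image the conversion amounts to reversing the channel axis, as this small NumPy sketch with a made-up 1 × 2 image shows.

```python
import numpy as np

# A 1 x 2 "image" in OpenCV's BGR order: a pure-blue pixel and a pure-red pixel
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the last axis swaps B and R, the same result as
# cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for a 3-channel image
rgb = bgr[..., ::-1]

print(rgb[0, 0], rgb[0, 1])  # [  0   0 255] [255   0   0]
```

The reversed array is a view with a negative stride; if a library rejects it, wrap it in np.ascontiguousarray, or simply keep using cv2.cvtColor as in the demo.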
  4. Load the pretrained model checkpoint sam_vit_h_4b8939.pth (ViT-H), which is the default model. Lighter versions are also available, such as sam_vit_l_0b3195.pth (ViT-L) and sam_vit_b_01ec64.pth (ViT-B).
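from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "/content/sam_vit_h_4b8939.pth"
# "default" is an alias for "vit_h"; use "vit_l" or "vit_b" with the lighter checkpoints
model_type = "default"
# Fall back to the CPU when no GPU is available (much slower)
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

# Compute the image embedding once; subsequent prompts reuse it
predictor.set_image(image)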
  5. Visualize the point on the image (the user prompt) that identifies our target object, the first grocery bag.
input_point = np.array([[465, 300]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
Input image with user prompt (image from Facebook research)
  6. Make a prediction to generate a mask of the object.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
print(masks.shape)  # (number_of_masks) x H x W
  7. Show the three generated masks. With multimask_output=True, the model returns three masks; we can then select the one with the highest score.
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show() 
Prediction results (image by the author)

The highlighted regions are the masks predicted by the model. As the results show, the model generated three output masks with the following prediction scores: Mask 1: 0.990, Mask 2: 0.875, and Mask 3: 0.827. We select Mask 1, which has the highest score. Voilà! The predicted mask is exactly the target object we wanted to segment. The result is impressive; the model works very well.
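Selecting the highest-scoring mask takes one line with np.argmax. The sketch below uses made-up stand-ins for the masks and scores returned by predictor.predict (three boolean H × W masks plus the three scores reported above), so it runs without the model.

```python
import numpy as np

# Hypothetical stand-ins for the outputs of predictor.predict(..., multimask_output=True)
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, 1:3, 1:3] = True   # small mask
masks[1, 0:3, 0:3] = True   # medium mask
masks[2, :, :] = True       # whole image
scores = np.array([0.990, 0.875, 0.827])

best_idx = int(np.argmax(scores))  # index of the highest predicted score
best_mask = masks[best_idx]        # H x W boolean mask
print(best_idx, int(best_mask.sum()))  # 0 4
```

If you only ever want a single mask, passing multimask_output=False to predictor.predict returns one mask instead. The three-mask output is useful when a single point is ambiguous, for example when it could refer to a whole object or to one of its parts.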

Part 2: Fully automatic image segmentation (continuing from Part 1)

  1. Plotting function for the segments
def show_anns(anns):
    if len(anns) == 0:
        return
    # Draw the largest masks first so that smaller masks stay visible on top
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']  # boolean H x W mask
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        # Use the mask (scaled to 0.35) as the alpha channel of a solid-color RGBA image
        ax.imshow(np.dstack((img, m*0.35)))
  2. Generate masks automatically
from segment_anything import SamAutomaticMaskGenerator

# Reuse the SAM model loaded in Part 1 (same checkpoint, model type and device)
mask_generator = SamAutomaticMaskGenerator(sam)

# Returns a list of dicts, one per detected mask
masks = mask_generator.generate(image)
print(len(masks))
  3. Show the result
plt.figure(figsize=(5,5))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 
Automatic segmentation result by SAM (image by the author)

The algorithm identified 137 different objects (masks) using the default parameters. Each mask comes with information such as the segment area, bounding box coordinates, predicted IoU score, and stability score, which can be used to filter out unwanted segments.
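As a sketch of such filtering: each element returned by SamAutomaticMaskGenerator.generate is a dict whose keys include 'segmentation', 'area', 'bbox' (in XYWH format), 'predicted_iou', and 'stability_score'. The function below keeps masks that pass a minimum area and minimum score thresholds. The function name filter_anns and its threshold values are illustrative assumptions, and the input is a list of made-up dicts containing only the keys used here, so the code runs without the model.

```python
def filter_anns(anns, min_area=500, min_pred_iou=0.9, min_stability=0.95):
    """Keep annotations that are large enough and confidently scored (illustrative thresholds)."""
    return [
        a for a in anns
        if a["area"] >= min_area
        and a["predicted_iou"] >= min_pred_iou
        and a["stability_score"] >= min_stability
    ]

# Made-up annotations with only the keys used by filter_anns
anns = [
    {"area": 12000, "predicted_iou": 0.97, "stability_score": 0.98},  # large and confident
    {"area": 150,   "predicted_iou": 0.99, "stability_score": 0.99},  # too small
    {"area": 8000,  "predicted_iou": 0.85, "stability_score": 0.97},  # low predicted IoU
]
kept = filter_anns(anns)
print(len(kept))  # 1
```

Applied to the real output, show_anns(filter_anns(masks)) would plot only the retained segments. Alternatively, SamAutomaticMaskGenerator accepts thresholds such as pred_iou_thresh, stability_score_thresh, and min_mask_region_area in its constructor, so filtering can also happen during generation.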


I hope you enjoyed it and can now start building beautiful apps yourself. If you have any questions or would like to share your thoughts about this article, feel free to comment; I will be happy to answer.

If you want to support my work directly and also get unlimited access to Medium articles, become a Medium member using my referral link here. Thank you a million times and have a nice day!

Join Medium with my referral link – Gurami Keretchashvili

References

[1] Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollar, and Ross Girshick. Masked autoencoders are scalable vision learners. CVPR, 2022.

[2] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16×16 words: Transformers for image recognition at scale. ICLR, 2021.

[3] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollar, and Ross Girshick. Segment Anything. ICCV, 2023.

My previous articles about ML deployment

How to Deploy Machine Learning models? End-to-End Dog Breed Identification Project!

How to Deploy Machine Learning Models

