Oxygen Chrome

공부/Deep Learning

[SAM2] Segment Anything 모델 구현 및 시각화 | 이미지 내 모든 마스크/바운딩박스 생성하고 추출하기

aribae 2024. 11. 15. 16:09

 

Ubuntu 24.04 / Python 3.12 / NVIDIA GeForce RTX 4090 환경에서 구현되었습니다.


논문

https://scontent-ssn1-1.xx.fbcdn.net/v/t39.2365-6/464917098_581932941165933_4465312900778079623_n.pdf?_nc_cat=105&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=IVHJDWR-y3gQ7kNvgFEPeY_&_nc_zt=14&_nc_ht=scontent-ssn1-1.xx&_nc_gid=AQSrWWdiVsabchv5sb78BW8&oh=00_AYCIP9z5dqCVB0BhW8s4zb2NtnvFgvN5tu4bVq2_hJm8mg&oe=673C81F2

 

깃허브

https://github.com/facebookresearch/sam2

 

GitHub - facebookresearch/sam2: The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use th...

github.com

 

논문 리뷰는 잘 해두신 분들이 많으니 .. 고걸 참고하셔용 - ̗̀( ˶'ᵕ'˶) ̖́-


 

1. Intro

오늘은 귀여운 무뎅이 사진으로 SAM2 Segmentation 마스크 및 바운딩박스를 생성해보아요 (◕ᴗ̵◕)♡

모든 코드는 SAM2 깃허브를 바탕으로 구현되었으나, 일부 다른 부분이 있으니 확인해주세요 ! 

 

전체 코드를 사용하고 싶다면? 👉 https://github.com/usnuni/segment-anything2 

원본 이미지 / 생성된 마스크 / 마스크의 바운딩 박스

 

 


2. Download SAM2

본인이 사용하고자 하는 경로에 맞춰 SAM2를 클론해줍니다.

git clone https://github.com/facebookresearch/sam2.git && cd sam2
pip install -e .

 

2.1 Download Checkpoints

모델을 직접 학습하는게 아니니 체크포인트도 다운로드해줍니다.

cd checkpoints && \
./download_ckpts.sh && \

 

 


 

3. 필요한 라이브러리 Import

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from utils import show_anns

 

3.1 Segmentation 마스크 시각화 코드

마스크를 이미지에 랜덤한 컬러로 시각화를 해주는 코드를 작성합니다.

아래는 제가 사용한 코드를 공유드려요.

import matplotlib.pyplot as plt

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask 
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) 
            
    return ax, img

 

 


 

4. 모델 불러오기

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

 

ckpt, config 는 본인이 다운로드한 경로에 맞춰 변경해주세요. ( ᐛ )σ

# load model
sam2_checkpoint = "path/to/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "path/toconfigs/sam2.1/sam2.1_hiera_l.yaml"
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)

 

 


 

5. 이미지 불러오기

# load image 
image_path = 'path/to/mudeng.jpeg'
# open image
image = Image.open(image_path)
# np.array 변경
image = np.array(image)

 

5.1. 결과 저장 경로 설정 (선택)

save_path = 'path/to/save'

 

 


 

6. 마스크 생성하기

# generate mask
masks = mask_generator.generate(image)

 

6.1 마스크 이미지 상에 띄우기

실행하게 되면 아래와 같은 이미지가 플롯되고, 이미지가 저장됩니다. 

plt.figure(figsize=(20, 20))
plt.imshow(image)
ax, img = show_anns(masks)
ax.imshow(img)
plt.axis('off')
plt.savefig(os.path.join(save_path, 'sam2_mask.png'))
plt.show()

 

결과

 

 

 


 

7. 각각의 마스크 크기에 맞는 바운딩박스 생성하기

bboxes = []
for mask in masks:
    bbox = mask['bbox']
    # bbox = [x, y, w, h]
    bboxes.append(bbox)

bboxes = np.array(bboxes)
bbox_areas = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])
bbox_areas = bbox_areas.astype(int)

 

7.1 바운딩박스 이미지 상에 띄우기

plt.figure(figsize=(20, 20))
plt.imshow(image)
for bbox in bboxes:
    x, y, w, h = bbox
    plt.plot([x, x+w, x+w, x, x], [y, y, y+h, y+h, y], color='r', linewidth=2)
plt.axis('off')
plt.savefig(os.path.join(save_path, 'sam2_bbox.png'))
plt.show()

 

 


 

8. 각각의 마스크/바운딩박스 저장하기

플롯을 할 때는 int로 변환하지 않아도 되는데,, 

TypeError: slice indices must be integers or None or have an __index__ method

에러가 출력되어 int로 변환을 해주니 잘 저장되었습니다.

for i, bbox in enumerate(bboxes):
    x, y, w, h = bbox
    x, y, w, h = int(x), int(y), int(w), int(h)
    cropped = image[y:y+h, x:x+w]
    cropped = Image.fromarray(cropped)
    cropped.save(os.path.join(save_path, f'{i+1}.png'))

 

결과

 

 

 


 

Discussion

저는 무작위한 패턴과 같은 복잡한 이미지를 주로 다루고 있습니다.

(보안상 이미지를 공개할 수 는 없지만.. 아래와 비슷한 느낌입니다.) 

 

포스팅에 사용한 이미지를 보아도 모든 객체를 구분하지는 못하고 있는데,

더 많은 포인트를 잡아 segmentation 할 수 있도록 코드를 좀 더 들여다 봐야겠습니다.

 

어떤 부분을 조정하면 좋은지 아시는 분이 있다면 댓글 남겨주시면 감사하겠습니다. ( •̀ᴗ•́ )و ̑̑

 

 

 

그럼, 읽어주셔서 감사합니다.