first commit
vr180_matting/sam2_wrapper.py | 226 lines added (new file)
@@ -0,0 +1,226 @@
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization"""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l.yaml",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install sam2 package.")

        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}

        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load SAM2 video predictor with optimizations"""
        try:
            self.predictor = build_sam2_video_predictor(
                model_cfg,
                checkpoint_path,
                device=self.device
            )

            # Enable memory optimizations
            if self.memory_offload:
                self.predictor.fill_hole_area = 8

            if self.fp16 and self.device == "cuda":
                self.predictor.model.half()

        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")

    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
        """Initialize video inference state"""
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")

        # Create temporary directory for frames if needed
        self.inference_state = self.predictor.init_state(
            video_path=None,
            video_frames=video_frames,
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index to add prompts
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,)

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids = []

        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Start from 1

            # Add box prompt
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )

            object_ids.extend(out_obj_ids)

        return object_ids

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through video

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}

        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}

            for i, out_obj_id in enumerate(out_obj_ids):
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask

            video_segments[out_frame_idx] = frame_masks

            # Memory management: release old frames periodically
            if self.memory_offload and out_frame_idx % 100 == 0:
                self._release_old_frames(out_frame_idx - 50)

        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory"""
        try:
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> np.ndarray:
        """Combine masks from multiple objects into single mask"""
        if not frame_masks:
            return None

        combined_mask = np.zeros_like(next(iter(frame_masks.values())), dtype=bool)

        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)

        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: List[int] = [0, 255, 0]) -> np.ndarray:
        """
        Apply mask to frame to create matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: RGB background color for greenscreen

        Returns:
            Matted frame
        """
        if mask is None:
            return frame

        # Ensure mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize mask to match frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)

        if output_format == "alpha":
            # Create RGBA output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output

        elif output_format == "greenscreen":
            # Create RGB output with background
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output

        else:
            raise ValueError(f"Unsupported output format: {output_format}")

    def cleanup(self):
        """Clean up resources"""
        if self.inference_state is not None:
            try:
                if hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)
            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")

            self.inference_state = None

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Destructor to ensure cleanup"""
        self.cleanup()
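

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how this wrapper is intended to be
# driven end to end. The frame-loading helper `load_frames`, the input path
# "clip.mp4", and the example person box are hypothetical placeholders, not
# part of this module; it also assumes the sam2 package and a checkpoint are
# installed locally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def load_frames(path: str) -> List[np.ndarray]:
        """Hypothetical helper: read a short clip into a list of HxWx3 RGB uint8 frames."""
        cap = cv2.VideoCapture(path)
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
        return frames

    frames = load_frames("clip.mp4")  # assumed input path

    matting = SAM2VideoMatting()
    matting.init_video_state(frames)

    # One example person box (x1, y1, x2, y2) on frame 0; label 1 = foreground.
    boxes = np.array([[100, 50, 400, 700]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    matting.add_person_prompts(frame_idx=0, box_prompts=boxes, labels=labels)

    # Propagate through the clip, then combine per-object masks and matte each frame.
    segments = matting.propagate_masks()
    for idx, frame_masks in segments.items():
        mask = matting.get_combined_mask(frame_masks)
        matted = matting.apply_mask_to_frame(frames[idx], mask, output_format="alpha")
        # `matted` is an RGBA frame; write it out or feed it to the next pipeline stage here.

    matting.cleanup()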