245 lines
8.8 KiB
Python
245 lines
8.8 KiB
Python
import contextlib
import inspect
import os
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
import torch

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization.

    Typical use: construct (loads the SAM2 video predictor), call
    :meth:`init_video_state` with the frames, :meth:`add_person_prompts` with
    box prompts on a frame, :meth:`propagate_masks`, then composite with
    :meth:`get_combined_mask` / :meth:`apply_mask_to_frame`. Call
    :meth:`cleanup` when done.
    """

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        """Create the wrapper and load the SAM2 video predictor.

        Args:
            model_cfg: SAM2 config name passed to ``build_sam2_video_predictor``.
            checkpoint_path: Checkpoint file. If it does not exist, a file with
                the same name is searched for in fallback directories
                (see :meth:`_resolve_checkpoint`).
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``.
            memory_offload: Offload video frames to CPU and periodically ask the
                predictor to release old frames (when it supports that).
            fp16: Run inference under float16 autocast on CUDA devices.

        Raises:
            ImportError: If the sam2 package is not installed.
            RuntimeError: If the model cannot be loaded.
        """
        # Assign every attribute before anything can raise, so cleanup() and
        # __del__ never observe a half-initialised instance.
        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self._frame_dir: Optional[str] = None

        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install sam2 package.")

        self._load_model(model_cfg, checkpoint_path)

    @staticmethod
    def _resolve_checkpoint(checkpoint_path: str) -> str:
        """Return the first existing checkpoint location, else ``checkpoint_path``.

        Search order: the path as given; ``segment-anything-2/checkpoints/<name>``
        and ``models/<name>`` relative to the working directory; then
        ``segment-anything-2/checkpoints/<name>`` under the directory two levels
        above this file.
        """
        given = Path(checkpoint_path)
        if given.exists():
            return checkpoint_path

        name = given.name
        for candidate in (Path("segment-anything-2/checkpoints") / name,
                          Path("models") / name):
            if candidate.exists():
                return str(candidate)

        package_dir = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        candidate = package_dir / "segment-anything-2/checkpoints" / name
        if candidate.exists():
            return str(candidate)

        # Nothing found: hand the original path to SAM2 so its error names it.
        return checkpoint_path

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Build the SAM2 video predictor on ``self.device``.

        Half precision is applied at inference time via :meth:`_autocast`
        instead of casting the weights. The upstream SAM2 video predictor is
        itself the ``nn.Module`` (it has no ``.model`` attribute), and
        half-precision weights would not match the float32 frame tensors fed
        to it.

        Raises:
            RuntimeError: Wrapping any error raised while building the model.
        """
        try:
            resolved = self._resolve_checkpoint(checkpoint_path)
            self.predictor = build_sam2_video_predictor(
                model_cfg,
                resolved,
                device=self.device,
            )
            if self.memory_offload:
                # SAM2 post-processing setting (fills small holes in predicted
                # masks). Kept from the original configuration; it is not by
                # itself a memory saving.
                self.predictor.fill_hole_area = 8
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}") from e

    def _autocast(self):
        """Return a float16 autocast context on CUDA when fp16 is on, else a no-op."""
        if self.fp16 and torch.device(self.device).type == "cuda":
            return torch.autocast(device_type="cuda", dtype=torch.float16)
        return contextlib.nullcontext()

    def _accepts_video_frames(self) -> bool:
        """True if the predictor's ``init_state`` takes a ``video_frames`` argument."""
        try:
            params = inspect.signature(self.predictor.init_state).parameters
        except (TypeError, ValueError):
            return False
        return "video_frames" in params

    def _write_frames(self, video_frames: List[np.ndarray]) -> str:
        """Write frames as ``00000.jpg``, ``00001.jpg``, ... into a new temp dir.

        The directory is recorded in ``self._frame_dir`` (removed by
        :meth:`cleanup`) and its path is returned.

        Raises:
            RuntimeError: If a frame cannot be encoded/written.
        """
        frame_dir = tempfile.mkdtemp(prefix="sam2_frames_")
        self._frame_dir = frame_dir
        for idx, frame in enumerate(video_frames):
            # NOTE(review): assumes RGB frames (this class documents RGB colours);
            # cv2.imwrite expects BGR, hence the conversion. Confirm with callers.
            bgr = cv2.cvtColor(np.ascontiguousarray(frame), cv2.COLOR_RGB2BGR)
            out_path = os.path.join(frame_dir, f"{idx:05d}.jpg")
            if not cv2.imwrite(out_path, bgr):
                raise RuntimeError(f"Failed to write frame {idx} to {out_path}")
        return frame_dir

    def _remove_frame_dir(self) -> None:
        """Delete the temporary frame directory, if one was created."""
        frame_dir = getattr(self, "_frame_dir", None)
        self._frame_dir = None
        if frame_dir is not None:
            shutil.rmtree(frame_dir, ignore_errors=True)

    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
        """Initialize video inference state for ``video_frames``.

        If the installed predictor's ``init_state`` accepts ``video_frames``,
        the frames are passed directly. Otherwise (upstream SAM2 only accepts
        ``video_path``) they are written to a temporary JPEG directory, which
        is passed as ``video_path``.

        Raises:
            RuntimeError: If the model is not loaded or a frame cannot be written.
            ValueError: If ``video_frames`` is empty.
        """
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")
        if len(video_frames) == 0:
            raise ValueError("video_frames is empty")

        # Frames from a previous video are no longer needed.
        self._remove_frame_dir()

        with self._autocast():
            if self._accepts_video_frames():
                self.inference_state = self.predictor.init_state(
                    video_path=None,
                    video_frames=video_frames,
                    offload_video_to_cpu=self.memory_offload,
                    async_loading_frames=True,
                )
            else:
                frame_dir = self._write_frames(video_frames)
                self.inference_state = self.predictor.init_state(
                    video_path=frame_dir,
                    offload_video_to_cpu=self.memory_offload,
                    async_loading_frames=True,
                )

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index to add prompts
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,). Box prompts carry no labels in the
                predictor call, so the values are unused. Boxes are paired with
                labels via ``zip``, so only ``min(len(boxes), len(labels))``
                boxes are added.

        Returns:
            List of object IDs, one per added box (1, 2, 3, ...)

        Raises:
            RuntimeError: If :meth:`init_video_state` has not been called.
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids: List[int] = []
        with self._autocast():
            for obj_id, (box, _label) in enumerate(zip(box_prompts, labels), start=1):
                self.predictor.add_new_points_or_box(
                    inference_state=self.inference_state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    box=box,
                )
                # The predictor reports *all* tracked ids on every call, so
                # extending with its output would duplicate earlier ids.
                object_ids.append(obj_id)
        return object_ids

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through video

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}, where each mask is
            the boolean array ``logits > 0`` for that object.

        Raises:
            RuntimeError: If :meth:`init_video_state` has not been called.
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments: Dict[int, Dict[int, np.ndarray]] = {}
        # The generator body runs inside this context on each iteration.
        with self._autocast():
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=start_frame,
                max_frame_num_to_track=max_frames,
                reverse=False,
            ):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

                # Memory management: every 100 frames, release frames older
                # than 50 frames back (no-op when that index is not positive).
                if self.memory_offload and out_frame_idx % 100 == 0:
                    release_before = out_frame_idx - 50
                    if release_before > 0:
                        self._release_old_frames(release_before)

        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Ask the predictor to drop frames before ``before_frame_idx``, if supported.

        Failures are reported as warnings and do not interrupt propagation.
        """
        try:
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Union per-object masks into a single boolean mask.

        3-D masks (e.g. with a leading singleton axis) are squeezed *before*
        combining, so the result is 2-D when every mask is ``(H, W)`` or
        ``(1, H, W)``.

        Returns:
            The combined boolean mask, or ``None`` if ``frame_masks`` is empty.
        """
        if not frame_masks:
            return None

        squeezed = [m.squeeze() if m.ndim == 3 else m for m in frame_masks.values()]
        return np.logical_or.reduce(squeezed).astype(bool, copy=False)

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: Optional[np.ndarray],
                            output_format: str = "alpha",
                            background_color: Sequence[int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply mask to frame to create matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W) or (1, H, W); any nonzero value is treated
                as foreground. Resized with nearest-neighbour if its size
                differs from the frame. ``None`` returns ``frame`` unchanged.
            output_format: "alpha" or "greenscreen"
            background_color: RGB background color for greenscreen

        Returns:
            Matted frame: (H, W, 4) uint8 for "alpha", (H, W, 3) uint8 for
            "greenscreen".

        Raises:
            ValueError: For an unsupported ``output_format``.
        """
        if mask is None:
            return frame

        if mask.ndim == 3:
            mask = mask.squeeze()

        height, width = frame.shape[:2]
        if mask.shape[:2] != (height, width):
            # Nearest-neighbour keeps the mask binary while resizing.
            mask = cv2.resize(mask.astype(np.uint8),
                              (width, height),
                              interpolation=cv2.INTER_NEAREST)
        # Always index with a boolean mask: a uint8 mask would otherwise be
        # interpreted as integer (fancy) indices in the greenscreen branch.
        mask = mask.astype(bool)

        if output_format == "alpha":
            # Create RGBA output
            output = np.zeros((height, width, 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output

        if output_format == "greenscreen":
            # Create RGB output with background
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output

        raise ValueError(f"Unsupported output format: {output_format}")

    def cleanup(self):
        """Release SAM2 state, temporary frame files and cached CUDA memory.

        Safe to call repeatedly and on a partially constructed instance.
        """
        predictor = getattr(self, "predictor", None)
        if getattr(self, "inference_state", None) is not None:
            try:
                if hasattr(predictor, 'cleanup_state'):
                    predictor.cleanup_state(self.inference_state)
            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")
            self.inference_state = None

        self._remove_frame_dir()

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Best-effort cleanup when the object is garbage collected."""
        try:
            self.cleanup()
        except Exception:
            # Finalizers must not raise; during interpreter shutdown module
            # globals (torch, shutil, ...) may already be torn down.
            pass