# test2/vr180_matting/sam2_wrapper.py
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
import cv2
from pathlib import Path
import warnings
import uuid
import gc

try:
    from sam2.build_sam import build_sam2_video_predictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install the sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization"""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install the sam2 package.")
        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16  # stored for callers; not currently used by this wrapper
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self.temp_video_path = None
        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load SAM2 video predictor with optimizations"""
        try:
            # If the checkpoint is not at the given path, search the usual
            # locations: a SAM2 repo checkout, the legacy models/ directory,
            # and the SAM2 repo relative to this package
            if not Path(checkpoint_path).exists():
                ckpt_name = Path(checkpoint_path).name
                package_dir = Path(__file__).resolve().parent.parent
                candidates = [
                    Path("segment-anything-2/checkpoints") / ckpt_name,
                    Path("models") / ckpt_name,
                    package_dir / "segment-anything-2/checkpoints" / ckpt_name,
                ]
                for candidate in candidates:
                    if candidate.exists():
                        checkpoint_path = str(candidate)
                        break
            # build_sam2_video_predictor returns the predictor directly;
            # the predictor IS the model - no .model attribute needed
            self.predictor = build_sam2_video_predictor(
                config_file=model_cfg,
                ckpt_path=checkpoint_path,
                device=self.device
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}") from e

    def init_video_state(self, video_frames: Optional[List[np.ndarray]] = None, video_path: Optional[str] = None) -> None:
        """Initialize video inference state"""
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")
        if video_path is not None:
            # Use the video path directly (SAM2's preferred method)
            self.inference_state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )
        else:
            # For frame arrays, save them as a temporary video first
            if video_frames is None or len(video_frames) == 0:
                raise ValueError("Either video_path or video_frames must be provided")
            # Create a uniquely named temporary video file in the current directory
            temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
            temp_video_path = Path.cwd() / temp_video_name
            # Write frames to the temporary video; frames are assumed to be
            # BGR uint8 arrays (as produced by cv2), and the fps is arbitrary
            # since SAM2 only consumes the frame sequence
            height, width = video_frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
            for frame in video_frames:
                writer.write(frame)
            writer.release()
            # Initialize with the temporary video
            self.inference_state = self.predictor.init_state(
                video_path=str(temp_video_path),
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )
            # Store the temp path for cleanup
            self.temp_video_path = temp_video_path
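
    # Usage sketch (hypothetical paths): either form works, e.g.
    #   matting.init_video_state(video_path="clip.mp4")
    #   matting.init_video_state(video_frames=frames)  # frames: HxWx3 BGR arrays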

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index to add prompts at
            box_prompts: Bounding boxes (N, 4) in (x1, y1, x2, y2) pixel coordinates
            labels: Prompt labels (N,); unused for box prompts but kept for API symmetry

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")
        object_ids = []
        for i, (box, _label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Object IDs start from 1
            # Add the box prompt (box prompts do not require point labels)
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
            object_ids.extend(out_obj_ids)
        return object_ids
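
    # Example (hypothetical values): one box per person on the prompt frame:
    #   boxes = np.array([[100, 50, 400, 700]], dtype=np.float32)
    #   ids = matting.add_person_prompts(frame_idx=0, box_prompts=boxes,
    #                                    labels=np.array([1]))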

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through the video

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")
        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                # Threshold logits at 0 to get a boolean mask
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask
            video_segments[out_frame_idx] = frame_masks
            # Memory management: release old frames periodically
            if self.memory_offload and out_frame_idx > 0 and out_frame_idx % 100 == 0:
                self._release_old_frames(out_frame_idx - 50)
        return video_segments
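
    # Shape note (from the thresholding above): video_segments[frame_idx][obj_id]
    # is a bool array, typically (1, H, W) before squeezing.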

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory"""
        try:
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Combine masks from multiple objects into a single mask"""
        if not frame_masks:
            return None
        # Use a squeezed template so the combined mask is always (H, W)
        template = np.squeeze(next(iter(frame_masks.values())))
        combined_mask = np.zeros_like(template, dtype=bool)
        for mask in frame_masks.values():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)
        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame to create matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: Greenscreen background color, in the frame's
                channel order (the default is green in both RGB and BGR)

        Returns:
            Matted frame
        """
        if mask is None:
            return frame
        # Ensure mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()
        # Resize mask to match frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        if output_format == "alpha":
            # Create a 4-channel output with the mask as the alpha channel
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output
        elif output_format == "greenscreen":
            # Fill with the background color, then paste the foreground through the mask
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output
        else:
            raise ValueError(f"Unsupported output format: {output_format}")
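
    # Example (hypothetical frame/mask): cv2.imwrite treats a 4-channel array
    # as BGRA, which lines up with cv2-decoded BGR frames:
    #   matted = matting.apply_mask_to_frame(frame, mask, output_format="alpha")
    #   cv2.imwrite("matted.png", matted)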

    def cleanup(self):
        """Clean up resources"""
        if self.inference_state is not None:
            try:
                # Reset SAM2 state first (critical for memory cleanup)
                if hasattr(self.predictor, 'reset_state'):
                    self.predictor.reset_state(self.inference_state)
                # Fall back to cleanup_state if available
                elif hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)
                # Explicitly delete inference state and video segments
                del self.inference_state
                if hasattr(self, 'video_segments') and self.video_segments:
                    del self.video_segments
                    self.video_segments = {}
            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")
            finally:
                self.inference_state = None
        # Explicitly delete the predictor
        if self.predictor is not None:
            try:
                del self.predictor
            except Exception as e:
                warnings.warn(f"Failed to delete predictor: {e}")
            finally:
                self.predictor = None
        # Clean up the temporary video file
        if self.temp_video_path is not None:
            try:
                if self.temp_video_path.exists():
                    self.temp_video_path.unlink()
                self.temp_video_path = None
            except Exception as e:
                warnings.warn(f"Failed to cleanup temp video: {e}")
        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Force garbage collection (critical for memory leak prevention)
        gc.collect()

    def __del__(self):
        """Destructor to ensure cleanup"""
        try:
            self.cleanup()
        except Exception:
            # Never raise from the destructor (cleanup may run on a
            # partially initialized instance if __init__ failed early)
            pass
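

# Minimal end-to-end sketch (not part of the original module). It assumes a
# SAM2 checkpoint resolvable by _load_model and a short clip at the
# hypothetical path "test_clip.mp4".
if __name__ == "__main__":
    matting = SAM2VideoMatting(device="cuda" if torch.cuda.is_available() else "cpu")
    matting.init_video_state(video_path="test_clip.mp4")
    # One person box (x1, y1, x2, y2) on frame 0
    boxes = np.array([[100, 50, 400, 700]], dtype=np.float32)
    matting.add_person_prompts(frame_idx=0, box_prompts=boxes, labels=np.array([1]))
    segments = matting.propagate_masks()
    print(f"Propagated masks for {len(segments)} frames")
    # Matte the first frame as RGBA and save it
    cap = cv2.VideoCapture("test_clip.mp4")
    ok, frame0 = cap.read()
    cap.release()
    if ok and 0 in segments:
        combined = matting.get_combined_mask(segments[0])
        matted = matting.apply_mask_to_frame(frame0, combined, output_format="alpha")
        cv2.imwrite("frame0_matted.png", matted)
    matting.cleanup()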