fix again

This commit is contained in:
2025-07-26 08:54:03 -07:00
parent d933d6b606
commit 3e21fd8678

View File

@@ -5,6 +5,8 @@ import cv2
from pathlib import Path from pathlib import Path
import warnings import warnings
import os import os
import tempfile
import shutil
try: try:
from sam2.build_sam import build_sam2_video_predictor from sam2.build_sam import build_sam2_video_predictor
@@ -33,6 +35,7 @@ class SAM2VideoMatting:
self.predictor = None self.predictor = None
self.inference_state = None self.inference_state = None
self.video_segments = {} self.video_segments = {}
self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path) self._load_model(model_cfg, checkpoint_path)
@@ -68,18 +71,46 @@ class SAM2VideoMatting:
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}") raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
    """Initialize the SAM2 video inference state.

    Exactly one source should be supplied:

    Args:
        video_frames: Optional list of H x W x 3 frames (BGR, as produced by
            OpenCV). Used only when ``video_path`` is not given; the frames are
            written to a temporary mp4 first, because SAM2's ``init_state``
            consumes a video path. NOTE(review): mp4v encoding is lossy, so
            pixel values may differ slightly from the input frames — confirm
            this is acceptable for matting quality.
        video_path: Optional path to an existing video file. This is SAM2's
            preferred input and is passed through directly.

    Raises:
        RuntimeError: if the SAM2 model is not loaded, or the temporary
            video writer cannot be opened.
        ValueError: if neither ``video_path`` nor ``video_frames`` is given.
    """
    if self.predictor is None:
        raise RuntimeError("SAM2 model not loaded")

    if video_path is not None:
        # Use the video path directly (SAM2's preferred method).
        self.inference_state = self.predictor.init_state(
            video_path=video_path,
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )
        return

    if video_frames is None or len(video_frames) == 0:
        raise ValueError("Either video_path or video_frames must be provided")

    # Release any temp video left over from a previous call so repeated
    # initialization does not leak temp directories.
    if self.temp_video_path is not None:
        shutil.rmtree(Path(self.temp_video_path).parent, ignore_errors=True)
        self.temp_video_path = None

    # Frame arrays must be staged as a temporary video file for SAM2.
    temp_dir = tempfile.mkdtemp()
    temp_video_path = Path(temp_dir) / "temp_video.mp4"

    height, width = video_frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # NOTE(review): 30 fps is assumed; SAM2 tracks per-frame, so fps only
    # affects the container metadata, not segmentation results.
    writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
    if not writer.isOpened():
        # Fail loudly instead of handing SAM2 an empty/invalid file.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise RuntimeError(f"Failed to open temporary video writer at {temp_video_path}")
    try:
        for frame in video_frames:
            writer.write(frame)
    finally:
        # Guarantee the file is finalized even if a write raises.
        writer.release()

    try:
        self.inference_state = self.predictor.init_state(
            video_path=str(temp_video_path),
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )
    except Exception:
        # Don't leak the temp directory if SAM2 rejects the staged video.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    # Store temp path so cleanup() can remove the directory later.
    self.temp_video_path = temp_video_path
def add_person_prompts(self, def add_person_prompts(self,
frame_idx: int, frame_idx: int,
@@ -231,6 +262,16 @@ class SAM2VideoMatting:
self.inference_state = None self.inference_state = None
# Clean up temporary video file
if self.temp_video_path is not None:
try:
if self.temp_video_path.exists():
# Remove the temporary directory
shutil.rmtree(self.temp_video_path.parent)
self.temp_video_path = None
except Exception as e:
warnings.warn(f"Failed to cleanup temp video: {e}")
# Clear CUDA cache # Clear CUDA cache
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()