Fix SAM2 video state initialization: remove the leftover unconditional init_state call and support both video_path and in-memory frame input
This commit is contained in:
@@ -5,6 +5,8 @@ import cv2
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
@@ -33,6 +35,7 @@ class SAM2VideoMatting:
|
||||
self.predictor = None
|
||||
self.inference_state = None
|
||||
self.video_segments = {}
|
||||
self.temp_video_path = None
|
||||
|
||||
self._load_model(model_cfg, checkpoint_path)
|
||||
|
||||
@@ -68,18 +71,46 @@ class SAM2VideoMatting:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
||||
|
||||
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
    """Initialize the SAM2 video inference state.

    Exactly one input source should be supplied. When ``video_path`` is
    given it is handed to SAM2 directly (SAM2's preferred input form).
    Otherwise ``video_frames`` are serialized to a temporary .mp4 file
    first, because ``predictor.init_state`` expects a video path.

    Args:
        video_frames: Optional list of BGR frames (H, W, C) to segment.
            Used only when ``video_path`` is not provided.
        video_path: Optional path to an existing video file.

    Raises:
        RuntimeError: If the SAM2 model has not been loaded.
        ValueError: If neither ``video_path`` nor ``video_frames`` is given.
    """
    if self.predictor is None:
        raise RuntimeError("SAM2 model not loaded")

    if video_path is not None:
        # Use the video file directly (SAM2's preferred method).
        self.inference_state = self.predictor.init_state(
            video_path=video_path,
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )
    else:
        # For frame arrays, we need to save them as a temporary video first.
        if video_frames is None or len(video_frames) == 0:
            raise ValueError("Either video_path or video_frames must be provided")

        # Create temporary video file.
        temp_dir = tempfile.mkdtemp()
        temp_video_path = Path(temp_dir) / "temp_video.mp4"

        # Write frames to the temporary video; assumes all frames share
        # the first frame's dimensions — TODO confirm upstream guarantees this.
        height, width = video_frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
        try:
            for frame in video_frames:
                writer.write(frame)
        finally:
            # Always release the writer so the container is finalized,
            # even if a frame write raises.
            writer.release()

        # Initialize with the temporary video.
        self.inference_state = self.predictor.init_state(
            video_path=str(temp_video_path),
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )

        # Store temp path so the cleanup method can remove the directory.
        self.temp_video_path = temp_video_path
||||
def add_person_prompts(self,
|
||||
frame_idx: int,
|
||||
@@ -231,6 +262,16 @@ class SAM2VideoMatting:
|
||||
|
||||
self.inference_state = None
|
||||
|
||||
# Clean up temporary video file
|
||||
if self.temp_video_path is not None:
|
||||
try:
|
||||
if self.temp_video_path.exists():
|
||||
# Remove the temporary directory
|
||||
shutil.rmtree(self.temp_video_path.parent)
|
||||
self.temp_video_path = None
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to cleanup temp video: {e}")
|
||||
|
||||
# Clear CUDA cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user