simplify
@@ -283,16 +283,29 @@ class SAM2StreamingProcessor:
         # Store features in state for this frame
         state['cached_features'][frame_idx] = backbone_out

-        # Add boxes as prompts for this specific frame
-        try:
-            # Force ensure all inputs are on correct device
-            boxes_tensor = boxes_tensor.to(self.device)
+        # Convert boxes to points for manual implementation
+        # SAM2 expects corner points from boxes with labels 2,3
+        points = []
+        labels = []
+        for box in boxes:
+            # Convert box [x1, y1, x2, y2] to corner points
+            x1, y1, x2, y2 = box
+            points.extend([[x1, y1], [x2, y2]])  # Top-left and bottom-right corners
+            labels.extend([2, 3])  # SAM2 standard labels for box corners

-            _, object_ids, masks = self.predictor.add_new_points_or_box(
+        points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
+        labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
+
+        try:
+            # Use add_new_points instead of add_new_points_or_box to avoid device issues
+            _, object_ids, masks = self.predictor.add_new_points(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=None,  # Let SAM2 auto-assign
-                box=boxes_tensor
+                points=points_tensor,
+                labels=labels_tensor,
+                clear_old_points=True,
+                normalize_coords=True
             )

         # Update state with object tracking info
@@ -300,32 +313,25 @@ class SAM2StreamingProcessor:
             state['tracking_has_started'] = True

         except Exception as e:
-            print(f"   Error in add_new_points_or_box: {e}")
-            print(f"   Box tensor device: {boxes_tensor.device}")
+            print(f"   Error in add_new_points: {e}")
+            print(f"   Points tensor device: {points_tensor.device}")
+            print(f"   Labels tensor device: {labels_tensor.device}")
             print(f"   Frame tensor device: {frame_tensor.device}")

-            # Check predictor components
-            print(f"   Checking predictor device placement:")
-            if hasattr(self.predictor, 'image_encoder'):
-                try:
-                    for name, param in self.predictor.image_encoder.named_parameters():
-                        if param.device.type != 'cuda':
-                            print(f"   image_encoder.{name}: {param.device}")
-                            break
-                except: pass
-
-            if hasattr(self.predictor, 'sam_prompt_encoder'):
-                try:
-                    for name, param in self.predictor.sam_prompt_encoder.named_parameters():
-                        if param.device.type != 'cuda':
-                            print(f"   sam_prompt_encoder.{name}: {param.device}")
-                            break
-                except: pass
-
-            # Check for any CPU tensors in predictor
-            print(f"   Predictor type: {type(self.predictor)}")
-            print(f"   Available predictor attributes: {[attr for attr in dir(self.predictor) if not attr.startswith('_')]}")
-            raise
+            # Fallback: manually initialize object tracking
+            print(f"   Using fallback manual object initialization")
+            object_ids = [i for i in range(len(detections))]
+            state['obj_ids'] = object_ids
+            state['tracking_has_started'] = True
+
+            # Store detection info for later use
+            for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
+                state['point_inputs_per_obj'][i] = {
+                    frame_idx: {
+                        'points': points_tensor[i*2:(i+1)*2],
+                        'labels': labels_tensor[i*2:(i+1)*2]
+                    }
+                }

         self.object_ids = object_ids
         print(f"   Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
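The heart of the change above is the box-to-points conversion: instead of handing SAM2 a box prompt directly, each [x1, y1, x2, y2] detection is expanded into its two corner points with labels 2 (top-left) and 3 (bottom-right), the label convention SAM2 uses for box corners. A minimal standalone sketch of that conversion (the helper name and example box are illustrative, not part of the commit):

import torch

def boxes_to_point_prompts(boxes, device='cpu'):
    # Expand each [x1, y1, x2, y2] box into two labeled corner points
    points, labels = [], []
    for x1, y1, x2, y2 in boxes:
        points.extend([[x1, y1], [x2, y2]])  # top-left, bottom-right
        labels.extend([2, 3])                # SAM2 box-corner labels
    points_t = torch.tensor(points, dtype=torch.float32, device=device)
    labels_t = torch.tensor(labels, dtype=torch.int32, device=device)
    return points_t, labels_t

# One detection box -> two prompt points
pts, lbls = boxes_to_point_prompts([[10.0, 20.0, 110.0, 220.0]])
assert pts.shape == (2, 2) and lbls.tolist() == [2, 3]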
vr180_streaming/sam2_streaming_simple.py (new file, 242 lines)
@@ -0,0 +1,242 @@
+"""
+Simple SAM2 streaming processor based on det-sam2 pattern
+Adapted for current segment-anything-2 API
+"""
+
+import torch
+import numpy as np
+import cv2
+import tempfile
+import os
+from pathlib import Path
+from typing import Dict, Any, List, Optional
+import warnings
+import gc
+
+# Import SAM2 components
+try:
+    from sam2.build_sam import build_sam2_video_predictor
+except ImportError:
+    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
+
+
+class SAM2StreamingProcessor:
+    """Simple streaming integration with SAM2 following det-sam2 pattern"""
+
+    def __init__(self, config: Dict[str, Any]):
+        self.config = config
+        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
+
+        # SAM2 model configuration
+        model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
+        checkpoint = config.get('matting', {}).get('sam2_checkpoint',
+                                                   'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
+
+        # Build predictor (simple, clean approach)
+        self.predictor = build_sam2_video_predictor(
+            model_cfg,
+            checkpoint,
+            device=self.device
+        )
+
+        # Frame buffer for streaming (like det-sam2)
+        self.frame_buffer = []
+        self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10)
+
+        # State management (simple)
+        self.inference_state = None
+        self.temp_dir = None
+        self.object_ids = []
+
+        # Memory management
+        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
+        self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300)
+
+        print(f"🎯 Simple SAM2 streaming processor initialized:")
+        print(f"   Model: {model_cfg}")
+        print(f"   Device: {self.device}")
+        print(f"   Buffer size: {self.frame_buffer_size}")
+        print(f"   Memory offload: {self.memory_offload}")
+
+    def add_frame_and_detections(self,
+                                 frame: np.ndarray,
+                                 detections: List[Dict[str, Any]],
+                                 frame_idx: int) -> np.ndarray:
+        """
+        Add frame to buffer and process detections (det-sam2 pattern)
+
+        Args:
+            frame: Input frame (BGR)
+            detections: List of detections with 'box' key
+            frame_idx: Global frame index
+
+        Returns:
+            Mask for current frame
+        """
+        # Add frame to buffer
+        self.frame_buffer.append({
+            'frame': frame,
+            'frame_idx': frame_idx,
+            'detections': detections
+        })
+
+        # Process when buffer is full or when we have detections
+        if len(self.frame_buffer) >= self.frame_buffer_size or detections:
+            return self._process_buffer()
+        else:
+            # Return empty mask if no processing yet
+            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+    def _process_buffer(self) -> np.ndarray:
+        """Process current frame buffer (adapted det-sam2 approach)"""
+        if not self.frame_buffer:
+            return np.zeros((480, 640), dtype=np.uint8)
+
+        try:
+            # Create temporary directory for frames (current SAM2 API requirement)
+            self._create_temp_frames()
+
+            # Initialize or update SAM2 state
+            if self.inference_state is None:
+                # First time: initialize state with temp directory
+                self.inference_state = self.predictor.init_state(
+                    video_path=self.temp_dir,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload
+                )
+                print(f"   Initialized SAM2 state with {len(self.frame_buffer)} frames")
+            else:
+                # Subsequent times: we need to reinitialize since current SAM2 lacks update_state
+                # This is the key difference from the det-sam2 reference
+                self._cleanup_temp_frames()
+                self._create_temp_frames()
+                self.inference_state = self.predictor.init_state(
+                    video_path=self.temp_dir,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload
+                )
+                print(f"   Reinitialized SAM2 state with {len(self.frame_buffer)} frames")
+
+            # Add detections as prompts (standard SAM2 API)
+            self._add_detection_prompts()
+
+            # Get masks via propagation
+            masks = self._get_current_masks()
+
+            # Clean up old frames to prevent memory accumulation
+            self._cleanup_old_frames()
+
+            return masks
+
+        except Exception as e:
+            print(f"   Warning: Buffer processing failed: {e}")
+            return np.zeros((480, 640), dtype=np.uint8)
+
+    def _create_temp_frames(self):
+        """Create temporary directory with frame images for SAM2"""
+        if self.temp_dir:
+            self._cleanup_temp_frames()
+
+        self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')
+
+        for i, buffer_item in enumerate(self.frame_buffer):
+            frame = buffer_item['frame']
+            # Convert BGR to RGB for SAM2
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Save as JPEG (SAM2 expects JPEG images in directory)
+            frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
+            cv2.imwrite(frame_path, cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
+
+    def _add_detection_prompts(self):
+        """Add detection boxes as prompts to SAM2 (standard API)"""
+        for buffer_idx, buffer_item in enumerate(self.frame_buffer):
+            detections = buffer_item.get('detections', [])
+
+            for det_idx, detection in enumerate(detections):
+                box = detection['box']  # [x1, y1, x2, y2]
+
+                # Use standard SAM2 API
+                try:
+                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
+                        inference_state=self.inference_state,
+                        frame_idx=buffer_idx,  # Frame index within buffer
+                        obj_id=det_idx,  # Simple object ID
+                        box=np.array(box, dtype=np.float32)
+                    )
+
+                    # Track object IDs
+                    if det_idx not in self.object_ids:
+                        self.object_ids.append(det_idx)
+
+                except Exception as e:
+                    print(f"   Warning: Failed to add detection: {e}")
+                    continue
+
+    def _get_current_masks(self) -> np.ndarray:
+        """Get masks for current frame via propagation"""
+        if not self.object_ids:
+            # No objects to track
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        try:
+            # Use SAM2's propagate_in_video (standard API)
+            latest_frame_idx = len(self.frame_buffer) - 1
+            masks_for_frame = []
+
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=latest_frame_idx,
+                max_frame_num_to_track=1,  # Just current frame
+                reverse=False
+            ):
+                if out_frame_idx == latest_frame_idx:
+                    # Combine all object masks
+                    if len(out_mask_logits) > 0:
+                        combined_mask = np.zeros_like(out_mask_logits[0], dtype=bool)
+                        for mask_logit in out_mask_logits:
+                            mask = (mask_logit > 0.0).cpu().numpy()
+                            combined_mask = combined_mask | mask
+
+                        return (combined_mask * 255).astype(np.uint8)
+
+            # If no masks found, return empty
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        except Exception as e:
+            print(f"   Warning: Mask propagation failed: {e}")
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+    def _cleanup_old_frames(self):
+        """Clean up old frames from buffer (det-sam2 pattern)"""
+        # Keep only recent frames to prevent memory accumulation
+        if len(self.frame_buffer) > self.frame_buffer_size:
+            self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]
+
+        # Periodic GPU memory cleanup
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    def _cleanup_temp_frames(self):
+        """Clean up temporary frame directory"""
+        if self.temp_dir and os.path.exists(self.temp_dir):
+            import shutil
+            shutil.rmtree(self.temp_dir)
+        self.temp_dir = None
+
+    def cleanup(self):
+        """Clean up all resources"""
+        self._cleanup_temp_frames()
+        self.frame_buffer.clear()
+        self.object_ids.clear()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        gc.collect()
+
+        print("🧹 Simple SAM2 streaming processor cleaned up")
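For orientation, a minimal driver for the new processor might look like the sketch below. The config dict mirrors the keys read in __init__, but the concrete values, the zero frame, and the detection box are illustrative assumptions, not part of the commit:

import numpy as np
from vr180_streaming.sam2_streaming_simple import SAM2StreamingProcessor

# Hypothetical config matching the keys the constructor reads
config = {
    'hardware': {'device': 'cuda'},
    'matting': {
        'sam2_model_cfg': 'sam2.1_hiera_l',
        'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
        'memory_offload': True,
        'correction_interval': 300,
    },
    'streaming': {'buffer_frames': 10},
}

processor = SAM2StreamingProcessor(config)

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in BGR frame
detections = [{'box': [100, 100, 200, 300]}]      # one person box

mask = processor.add_frame_and_detections(frame, detections, frame_idx=0)  # prompt frame
mask = processor.add_frame_and_detections(frame, [], frame_idx=1)          # plain frame
processor.cleanup()

Note the trade-off the class makes explicit in _process_buffer: because the current segment-anything-2 API has no incremental update_state, every buffer flush rewrites the buffered frames as JPEGs and calls init_state again, trading throughput for a much simpler state model.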
@@ -15,7 +15,7 @@ import warnings
 from .frame_reader import StreamingFrameReader
 from .frame_writer import StreamingFrameWriter
 from .stereo_manager import StereoConsistencyManager
-from .sam2_streaming import SAM2StreamingProcessor
+from .sam2_streaming_simple import SAM2StreamingProcessor
 from .detector import PersonDetector
 from .config import StreamingConfig

@@ -102,26 +102,17 @@ class VR180StreamingProcessor:
             self.initialize()
             self.start_time = time.time()

-            # Initialize SAM2 states for both eyes (streaming mode - no video loading)
-            print("🎯 Initializing SAM2 streaming states...")
-            video_info = self.frame_reader.get_video_info()
-            left_state = self.sam2_processor.init_state(
-                video_info,
-                eye='left'
-            )
-            right_state = self.sam2_processor.init_state(
-                video_info,
-                eye='right'
-            )
+            # Simple SAM2 initialization (no complex state management needed)
+            print("🎯 SAM2 streaming processor ready...")

             # Process first frame to establish detections
             print("🔍 Processing first frame for initial detection...")
-            if not self._initialize_tracking(left_state, right_state):
+            if not self._initialize_tracking():
                 raise RuntimeError("Failed to initialize tracking - no persons detected")

             # Main streaming loop
             print("\n🎬 Starting streaming processing loop...")
-            self._streaming_loop(left_state, right_state)
+            self._streaming_loop()

         except KeyboardInterrupt:
             print("\n⚠️ Processing interrupted by user")
@@ -135,7 +126,7 @@ class VR180StreamingProcessor:
         finally:
             self._finalize()

-    def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
+    def _initialize_tracking(self) -> bool:
         """Initialize tracking with first frame detection"""
         # Read and process first frame
         first_frame = self.frame_reader.read_frame()
@@ -159,19 +150,15 @@ class VR180StreamingProcessor:

         print(f"   Detected {len(detections)} person(s) in first frame")

-        # Add detections to both eyes (streaming - pass frame data)
-        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
+        # Process with simple SAM2 approach
+        left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, 0)

-        # Transfer detections to slave eye
+        # Transfer detections to right eye
         transferred_detections = self.stereo_manager.transfer_detections(
             detections,
             'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
         )
-        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
-
-        # Process and write first frame
-        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
+        right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, 0)

         # Apply masks and write
         processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -180,7 +167,7 @@ class VR180StreamingProcessor:
         self.frames_processed = 1
         return True

-    def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
+    def _streaming_loop(self) -> None:
         """Main streaming processing loop"""
         frame_times = []
         last_log_time = time.time()
@@ -196,9 +183,9 @@ class VR180StreamingProcessor:
             # Split into eyes
             left_eye, right_eye = self.stereo_manager.split_frame(frame)

-            # Propagate masks for both eyes (streaming approach)
-            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
-            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
+            # Process frames with simple approach (no detections in regular frames)
+            left_masks = self.sam2_processor.add_frame_and_detections(left_eye, [], frame_idx)
+            right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx)

             # Validate stereo consistency
             right_masks = self.stereo_manager.validate_masks(
@@ -208,9 +195,7 @@ class VR180StreamingProcessor:
             # Apply continuous correction if enabled
             if (self.config.matting.continuous_correction and
                 frame_idx % self.config.matting.correction_interval == 0):
-                self._apply_continuous_correction(
-                    left_state, right_state, left_eye, right_eye, frame_idx
-                )
+                self._apply_continuous_correction(left_eye, right_eye, frame_idx)

             # Apply masks and write frame
             processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
@@ -282,21 +267,20 @@ class VR180StreamingProcessor:
         return left_processed

     def _apply_continuous_correction(self,
-                                     left_state: Dict,
-                                     right_state: Dict,
                                      left_eye: np.ndarray,
                                      right_eye: np.ndarray,
                                      frame_idx: int) -> None:
         """Apply continuous correction to maintain tracking accuracy"""
         print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

-        # Detect on master eye
+        # Detect on master eye and add fresh detections
         master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
-        master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state
+        detections = self.detector.detect_persons(master_eye)

-        self.sam2_processor.apply_continuous_correction(
-            master_state, master_eye, frame_idx, self.detector
-        )
+        if detections:
+            print(f"   Adding {len(detections)} fresh detection(s) for correction")
+            # Add fresh detections to help correct drift
+            self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx)

         # Transfer corrections to slave eye
         # Note: This is simplified - actual implementation would transfer the refined prompts
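Taken together, the pipeline's per-frame flow after this commit reduces to the sketch below; the method names all appear in the diff, while the loop skeleton itself is a paraphrase for illustration:

# Per-frame flow in VR180StreamingProcessor (paraphrased from the diff)
left_eye, right_eye = self.stereo_manager.split_frame(frame)

# Same entry point for both eyes; regular frames pass an empty detection list
left_masks = self.sam2_processor.add_frame_and_detections(left_eye, [], frame_idx)
right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx)

# Periodic re-detection on the master eye counteracts drift
if (self.config.matting.continuous_correction and
        frame_idx % self.config.matting.correction_interval == 0):
    self._apply_continuous_correction(left_eye, right_eye, frame_idx)

processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)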