diff --git a/core/sam2_processor.py b/core/sam2_processor.py
index 39c06fe..aed7aeb 100644
--- a/core/sam2_processor.py
+++ b/core/sam2_processor.py
@@ -17,16 +17,18 @@
 logger = logging.getLogger(__name__)
 
 class SAM2Processor:
     """Handles SAM2-based video segmentation for human tracking."""
 
-    def __init__(self, checkpoint_path: str, config_path: str):
+    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
         """
         Initialize SAM2 processor.
         Args:
             checkpoint_path: Path to SAM2 checkpoint
             config_path: Path to SAM2 config file
+            vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
         """
         self.checkpoint_path = checkpoint_path
         self.config_path = config_path
+        self.vos_optimized = vos_optimized
         self.predictor = None
         self._initialize_predictor()
@@ -62,12 +64,35 @@ class SAM2Processor:
 
         logger.info(f"Using SAM2 config: {config_name}")
 
-        self.predictor = build_sam2_video_predictor(
-            config_name,  # Use just the config name, not full path
-            self.checkpoint_path,
-            device=device,
-            overrides=dict(conf=0.95)
-        )
+        # Use VOS optimization if enabled and supported
+        if self.vos_optimized:
+            try:
+                self.predictor = build_sam2_video_predictor(
+                    config_name,  # Use just the config name, not full path
+                    self.checkpoint_path,
+                    device=device,
+                    vos_optimized=True  # New optimization for major speedup
+                )
+                logger.info("Using optimized SAM2 VOS predictor with full model compilation")
+            except Exception as e:
+                logger.warning(f"Failed to use optimized VOS predictor: {e}")
+                logger.info("Falling back to standard SAM2 predictor")
+                # Fallback to standard predictor
+                self.predictor = build_sam2_video_predictor(
+                    config_name,
+                    self.checkpoint_path,
+                    device=device,
+                    overrides=dict(conf=0.95)
+                )
+        else:
+            # Use standard predictor
+            self.predictor = build_sam2_video_predictor(
+                config_name,
+                self.checkpoint_path,
+                device=device,
+                overrides=dict(conf=0.95)
+            )
+            logger.info("Using standard SAM2 predictor")
 
         # Enable optimizations for CUDA
         if device.type == "cuda":
@@ -248,18 +273,30 @@ class SAM2Processor:
             Dictionary mapping frame indices to object masks
         """
 
         video_segments = {}
+        frame_count = 0
 
         try:
+            logger.info("Starting SAM2 mask propagation...")
             for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                 video_segments[out_frame_idx] = {
                     out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                     for i, out_obj_id in enumerate(out_obj_ids)
                 }
+                frame_count += 1
+
+                # Log progress every 50 frames
+                if frame_count % 50 == 0:
+                    logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
 
-            logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
+            logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
 
         except Exception as e:
-            logger.error(f"Error during mask propagation: {e}")
+            logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
+            logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
+            if frame_count == 0:
+                logger.error("No frames were processed - propagation failed completely")
+            else:
+                logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
 
         return video_segments
diff --git a/main.py b/main.py
index b0dffe1..394efba 100644
--- a/main.py
+++ b/main.py
@@ -277,7 +277,8 @@ def main():
     logger.info("Step 3: Initializing SAM2 processor")
     sam2_processor = SAM2Processor(
         checkpoint_path=config.get_sam2_checkpoint(),
-        config_path=config.get_sam2_config()
+        config_path=config.get_sam2_config(),
+        vos_optimized=config.get('models.sam2_vos_optimized', False)
     )
 
     # Initialize mask processor
@@ -595,6 +596,13 @@ def main():
             logger.error(f"SAM2 processing failed for segment {segment_idx}")
             continue
 
+        # Check if SAM2 produced adequate results
+        if len(video_segments) == 0:
+            logger.error(f"SAM2 produced no frames for segment {segment_idx}")
+            continue
+        elif len(video_segments) < 10:  # Expected many frames for a 5-second segment
+            logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
+
         # Debug what SAM2 produced
         logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
         logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")