diff --git a/core/sam2_processor.py b/core/sam2_processor.py index ab840a6..39c06fe 100644 --- a/core/sam2_processor.py +++ b/core/sam2_processor.py @@ -65,7 +65,8 @@ class SAM2Processor: self.predictor = build_sam2_video_predictor( config_name, # Use just the config name, not full path self.checkpoint_path, - device=device + device=device, + overrides=dict(conf=0.95) ) # Enable optimizations for CUDA @@ -539,13 +540,14 @@ class SAM2Processor: logger.error(f"Error generating first frame debug masks: {e}") return False - def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, List[Dict[str, Any]]]) -> bool: + def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool: """ - Add YOLO detection prompts at multiple frame indices for mid-segment re-detection. + Add YOLO prompts at multiple frame indices for mid-segment re-detection. + Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode). Args: inference_state: SAM2 inference state - multi_frame_prompts: Dictionary mapping frame_index -> list of prompt dictionaries + multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation) Returns: True if prompts were added successfully @@ -554,41 +556,60 @@ class SAM2Processor: logger.warning("SAM2 Mid-segment: No multi-frame prompts provided") return False - total_prompts = sum(len(prompts) for prompts in multi_frame_prompts.values()) - logger.info(f"SAM2 Mid-segment: Adding {total_prompts} prompts across {len(multi_frame_prompts)} frames") - success_count = 0 total_count = 0 - for frame_idx, prompts in multi_frame_prompts.items(): - logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} prompts") - - for i, prompt in enumerate(prompts): - obj_id = prompt['obj_id'] - bbox = prompt['bbox'] - confidence = prompt.get('confidence', 'unknown') - total_count += 1 + for frame_idx, prompts_data in multi_frame_prompts.items(): + # Check if this is segmentation mode (masks) or detection mode (bbox prompts) + if isinstance(prompts_data, dict) and 'masks' in prompts_data: + # Segmentation mode: add masks directly + masks_dict = prompts_data['masks'] + logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks") - logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}") + for obj_id, mask in masks_dict.items(): + total_count += 1 + logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}") + + try: + self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask) + logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully") + success_count += 1 + + except Exception as e: + logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}") + continue + + else: + # Detection mode: add bounding box prompts (existing logic) + prompts = prompts_data + logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts") - try: - _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box( - inference_state=inference_state, - frame_idx=frame_idx, # Key: specify the exact frame index - obj_id=obj_id, - box=bbox.astype(np.float32), - ) + for i, prompt in enumerate(prompts): + obj_id = prompt['obj_id'] + bbox = prompt['bbox'] + confidence = prompt.get('confidence', 'unknown') + total_count += 1 - logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}") - success_count += 1 + logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}") - except Exception as e: - logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}") - continue + try: + _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=frame_idx, # Key: specify the exact frame index + obj_id=obj_id, + box=bbox.astype(np.float32), + ) + + logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}") + success_count += 1 + + except Exception as e: + logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}") + continue if success_count > 0: logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames") return True else: logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added") - return False \ No newline at end of file + return False diff --git a/main.py b/main.py index 0f02e93..b0dffe1 100644 --- a/main.py +++ b/main.py @@ -507,9 +507,9 @@ def main(): logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}") logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}") - # Handle mid-segment detection if enabled (only when using YOLO prompts, not masks) + # Handle mid-segment detection if enabled (works for both detection and segmentation modes) multi_frame_prompts = None - if config.get('advanced.enable_mid_segment_detection', False) and yolo_prompts: + if config.get('advanced.enable_mid_segment_detection', False) and (yolo_prompts or has_yolo_masks): logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}") # Calculate frame indices for re-detection @@ -538,23 +538,48 @@ def main(): scale=config.get_inference_scale() ) - # Convert detections to SAM2 prompts + # Convert detections to SAM2 prompts (different handling for segmentation vs detection mode) multi_frame_prompts = {} cap = cv2.VideoCapture(segment_info['video_file']) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() for frame_idx, detections in multi_frame_detections.items(): if detections: - prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width) - multi_frame_prompts[frame_idx] = prompts - logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts") + if has_yolo_masks: + # Segmentation mode: convert YOLO masks to SAM2 mask prompts + frame_masks = {} + for i, detection in enumerate(detections[:2]): # Up to 2 objects + if detection.get('has_mask', False): + mask = detection['mask'] + # Resize mask to match inference scale + if config.get_inference_scale() != 1.0: + scale = config.get_inference_scale() + scaled_height = int(frame_height * scale) + scaled_width = int(frame_width * scale) + mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST) + mask = mask > 0.5 + + obj_id = i + 1 # Sequential object IDs + frame_masks[obj_id] = mask.astype(bool) + logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}") + + if frame_masks: + # Store as mask prompts (different format than bbox prompts) + multi_frame_prompts[frame_idx] = {'masks': frame_masks} + logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(frame_masks)} YOLO masks") + else: + # Detection mode: convert to bounding box prompts (existing logic) + prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width) + multi_frame_prompts[frame_idx] = prompts + logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts") logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames") else: logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)") elif config.get('advanced.enable_mid_segment_detection', False): - logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO prompts)") + logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO data)") # Process segment with SAM2 logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")