working still
@@ -65,7 +65,8 @@ class SAM2Processor:
         self.predictor = build_sam2_video_predictor(
             config_name,  # Use just the config name, not full path
             self.checkpoint_path,
-            device=device
+            device=device,
+            overrides=dict(conf=0.95)
         )

         # Enable optimizations for CUDA
@@ -539,13 +540,14 @@ class SAM2Processor:
             logger.error(f"Error generating first frame debug masks: {e}")
             return False

-    def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, List[Dict[str, Any]]]) -> bool:
+    def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
         """
-        Add YOLO detection prompts at multiple frame indices for mid-segment re-detection.
+        Add YOLO prompts at multiple frame indices for mid-segment re-detection.
+        Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode).

         Args:
             inference_state: SAM2 inference state
-            multi_frame_prompts: Dictionary mapping frame_index -> list of prompt dictionaries
+            multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation)

         Returns:
             True if prompts were added successfully
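For reference, the two payload shapes the updated signature accepts look roughly like this (a minimal sketch; the frame index, box coordinates, and mask size are illustrative, not taken from the repository):

    import numpy as np

    # Detection mode: frame_index -> list of bbox prompt dicts
    bbox_style = {
        30: [
            {'obj_id': 1, 'bbox': np.array([100, 150, 300, 400]), 'confidence': 0.91},
            {'obj_id': 2, 'bbox': np.array([500, 120, 700, 380]), 'confidence': 0.84},
        ],
    }

    # Segmentation mode: frame_index -> {'masks': {obj_id: boolean mask}}
    seg_style = {
        30: {'masks': {1: np.zeros((720, 1280), dtype=bool)}},
    }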
@@ -554,41 +556,60 @@ class SAM2Processor:
             logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
             return False

-        total_prompts = sum(len(prompts) for prompts in multi_frame_prompts.values())
-        logger.info(f"SAM2 Mid-segment: Adding {total_prompts} prompts across {len(multi_frame_prompts)} frames")
-
         success_count = 0
         total_count = 0

-        for frame_idx, prompts in multi_frame_prompts.items():
-            logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} prompts")
-
-            for i, prompt in enumerate(prompts):
-                obj_id = prompt['obj_id']
-                bbox = prompt['bbox']
-                confidence = prompt.get('confidence', 'unknown')
-                total_count += 1
-
-                logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
-
-                try:
-                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
-                        inference_state=inference_state,
-                        frame_idx=frame_idx,  # Key: specify the exact frame index
-                        obj_id=obj_id,
-                        box=bbox.astype(np.float32),
-                    )
-
-                    logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
-                    success_count += 1
-
-                except Exception as e:
-                    logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
-                    continue
+        for frame_idx, prompts_data in multi_frame_prompts.items():
+            # Check if this is segmentation mode (masks) or detection mode (bbox prompts)
+            if isinstance(prompts_data, dict) and 'masks' in prompts_data:
+                # Segmentation mode: add masks directly
+                masks_dict = prompts_data['masks']
+                logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")
+
+                for obj_id, mask in masks_dict.items():
+                    total_count += 1
+                    logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
+
+                    try:
+                        self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
+                        logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
+                        success_count += 1
+
+                    except Exception as e:
+                        logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
+                        continue
+
+            else:
+                # Detection mode: add bounding box prompts (existing logic)
+                prompts = prompts_data
+                logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")
+
+                for i, prompt in enumerate(prompts):
+                    obj_id = prompt['obj_id']
+                    bbox = prompt['bbox']
+                    confidence = prompt.get('confidence', 'unknown')
+                    total_count += 1
+
+                    logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
+
+                    try:
+                        _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
+                            inference_state=inference_state,
+                            frame_idx=frame_idx,  # Key: specify the exact frame index
+                            obj_id=obj_id,
+                            box=bbox.astype(np.float32),
+                        )
+
+                        logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
+                        success_count += 1
+
+                    except Exception as e:
+                        logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
+                        continue

         if success_count > 0:
             logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
             return True
         else:
             logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
             return False
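The mask branch registers prompts through SAM2's `add_new_mask` and the bbox branch through `add_new_points_or_box`; both only anchor prompts to a frame, so propagation still runs afterwards. A rough usage sketch, assuming a `SAM2Processor` instance and an already-initialized inference state (the driver loop below uses the standard SAM2 `propagate_in_video` API and is an assumption, not code from this commit):

    # Sketch only: prompts registered here are anchored to their frame_idx,
    # then propagated through the rest of the segment.
    if processor.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
        for frame_idx, obj_ids, mask_logits in processor.predictor.propagate_in_video(inference_state):
            masks = (mask_logits > 0.0).cpu().numpy()  # threshold logits to binary masks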
main.py
@@ -507,9 +507,9 @@ def main():
            logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}")
            logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}")

-        # Handle mid-segment detection if enabled (only when using YOLO prompts, not masks)
+        # Handle mid-segment detection if enabled (works for both detection and segmentation modes)
         multi_frame_prompts = None
-        if config.get('advanced.enable_mid_segment_detection', False) and yolo_prompts:
+        if config.get('advanced.enable_mid_segment_detection', False) and (yolo_prompts or has_yolo_masks):
             logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}")

             # Calculate frame indices for re-detection
@@ -538,23 +538,48 @@ def main():
                    scale=config.get_inference_scale()
                )

-                # Convert detections to SAM2 prompts
+                # Convert detections to SAM2 prompts (different handling for segmentation vs detection mode)
                 multi_frame_prompts = {}
                 cap = cv2.VideoCapture(segment_info['video_file'])
                 frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                 cap.release()

                 for frame_idx, detections in multi_frame_detections.items():
                     if detections:
-                        prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
-                        multi_frame_prompts[frame_idx] = prompts
-                        logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")
+                        if has_yolo_masks:
+                            # Segmentation mode: convert YOLO masks to SAM2 mask prompts
+                            frame_masks = {}
+                            for i, detection in enumerate(detections[:2]):  # Up to 2 objects
+                                if detection.get('has_mask', False):
+                                    mask = detection['mask']
+                                    # Resize mask to match inference scale
+                                    if config.get_inference_scale() != 1.0:
+                                        scale = config.get_inference_scale()
+                                        scaled_height = int(frame_height * scale)
+                                        scaled_width = int(frame_width * scale)
+                                        mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
+                                        mask = mask > 0.5
+
+                                    obj_id = i + 1  # Sequential object IDs
+                                    frame_masks[obj_id] = mask.astype(bool)
+                                    logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")

+                            if frame_masks:
+                                # Store as mask prompts (different format than bbox prompts)
+                                multi_frame_prompts[frame_idx] = {'masks': frame_masks}
+                                logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(frame_masks)} YOLO masks")
+                        else:
+                            # Detection mode: convert to bounding box prompts (existing logic)
+                            prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
+                            multi_frame_prompts[frame_idx] = prompts
+                            logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")

                 logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames")
             else:
                 logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)")
         elif config.get('advanced.enable_mid_segment_detection', False):
-            logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO prompts)")
+            logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO data)")

         # Process segment with SAM2
         logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")
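The resize step above is what keeps YOLO masks aligned with the downscaled frames SAM2 sees at the inference scale; a standalone sketch of that logic (the helper name and example sizes are illustrative, not taken from the repository):

    import cv2
    import numpy as np

    def scale_mask_to_inference(mask: np.ndarray, frame_width: int, frame_height: int, scale: float) -> np.ndarray:
        """Downscale a binary mask to the inference resolution, keeping it boolean."""
        if scale != 1.0:
            scaled_w, scaled_h = int(frame_width * scale), int(frame_height * scale)
            # Nearest-neighbour keeps the mask crisp; the float/threshold round trip mirrors the commit.
            resized = cv2.resize(mask.astype(np.float32), (scaled_w, scaled_h), interpolation=cv2.INTER_NEAREST)
            mask = resized > 0.5
        return mask.astype(bool)

    # Example: a 1280x720 mask at inference scale 0.5 becomes 640x360
    full_mask = np.zeros((720, 1280), dtype=bool)
    print(scale_mask_to_inference(full_mask, 1280, 720, 0.5).shape)  # (360, 640)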