working still
main.py · 39 changed lines
@@ -507,9 +507,9 @@ def main():
         logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}")
         logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}")

-        # Handle mid-segment detection if enabled (only when using YOLO prompts, not masks)
+        # Handle mid-segment detection if enabled (works for both detection and segmentation modes)
         multi_frame_prompts = None
-        if config.get('advanced.enable_mid_segment_detection', False) and yolo_prompts:
+        if config.get('advanced.enable_mid_segment_detection', False) and (yolo_prompts or has_yolo_masks):
             logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}")

             # Calculate frame indices for re-detection
@@ -538,23 +538,48 @@ def main():
                     scale=config.get_inference_scale()
                 )

-                # Convert detections to SAM2 prompts
+                # Convert detections to SAM2 prompts (different handling for segmentation vs detection mode)
                 multi_frame_prompts = {}
                 cap = cv2.VideoCapture(segment_info['video_file'])
                 frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                 cap.release()

                 for frame_idx, detections in multi_frame_detections.items():
                     if detections:
-                        prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
-                        multi_frame_prompts[frame_idx] = prompts
-                        logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")
+                        if has_yolo_masks:
+                            # Segmentation mode: convert YOLO masks to SAM2 mask prompts
+                            frame_masks = {}
+                            for i, detection in enumerate(detections[:2]):  # Up to 2 objects
+                                if detection.get('has_mask', False):
+                                    mask = detection['mask']
+                                    # Resize mask to match inference scale
+                                    if config.get_inference_scale() != 1.0:
+                                        scale = config.get_inference_scale()
+                                        scaled_height = int(frame_height * scale)
+                                        scaled_width = int(frame_width * scale)
+                                        mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
+                                        mask = mask > 0.5
+
+                                    obj_id = i + 1  # Sequential object IDs
+                                    frame_masks[obj_id] = mask.astype(bool)
+                                    logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
+
+                            if frame_masks:
+                                # Store as mask prompts (different format than bbox prompts)
+                                multi_frame_prompts[frame_idx] = {'masks': frame_masks}
+                                logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(frame_masks)} YOLO masks")
+                        else:
+                            # Detection mode: convert to bounding box prompts (existing logic)
+                            prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
+                            multi_frame_prompts[frame_idx] = prompts
+                            logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")

                 logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames")
             else:
                 logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)")
         elif config.get('advanced.enable_mid_segment_detection', False):
-            logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO prompts)")
+            logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO data)")

         # Process segment with SAM2
         logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")
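Note (illustration only, not part of the commit): after this change multi_frame_prompts can hold two different per-frame shapes. The sketch below assumes boolean NumPy masks; the exact dictionaries returned by detector.convert_detections_to_sam2_prompts are not visible in this diff, so the bbox-prompt entries shown are hypothetical.

import numpy as np

# Detection mode: frame index -> list of SAM2 prompts from
# convert_detections_to_sam2_prompts (layout here is an assumption).
detection_mode_prompts = {
    120: [
        {'obj_id': 1, 'box': [100, 50, 300, 400]},  # hypothetical prompt dict
        {'obj_id': 2, 'box': [500, 60, 700, 420]},
    ],
}

# Segmentation mode: frame index -> {'masks': {obj_id: boolean mask}},
# the structure stored by the new has_yolo_masks branch above.
segmentation_mode_prompts = {
    120: {'masks': {1: np.zeros((540, 960), dtype=bool),
                    2: np.zeros((540, 960), dtype=bool)}},
}

As the "different format than bbox prompts" comment in the diff suggests, downstream SAM2 code presumably has to check which of the two shapes a frame entry carries (a dict with a 'masks' key versus a list of prompts) before applying it.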