working still
This commit is contained in:
@@ -65,7 +65,8 @@ class SAM2Processor:
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
config_name, # Use just the config name, not full path
|
||||
self.checkpoint_path,
|
||||
device=device
|
||||
device=device,
|
||||
overrides=dict(conf=0.95)
|
||||
)
|
||||
|
||||
# Enable optimizations for CUDA
|
||||
@@ -539,13 +540,14 @@ class SAM2Processor:
|
||||
logger.error(f"Error generating first frame debug masks: {e}")
|
||||
return False
|
||||
|
||||
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, List[Dict[str, Any]]]) -> bool:
|
||||
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
|
||||
"""
|
||||
Add YOLO detection prompts at multiple frame indices for mid-segment re-detection.
|
||||
Add YOLO prompts at multiple frame indices for mid-segment re-detection.
|
||||
Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode).
|
||||
|
||||
Args:
|
||||
inference_state: SAM2 inference state
|
||||
multi_frame_prompts: Dictionary mapping frame_index -> list of prompt dictionaries
|
||||
multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation)
|
||||
|
||||
Returns:
|
||||
True if prompts were added successfully
|
||||
@@ -554,41 +556,60 @@ class SAM2Processor:
|
||||
logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
|
||||
return False
|
||||
|
||||
total_prompts = sum(len(prompts) for prompts in multi_frame_prompts.values())
|
||||
logger.info(f"SAM2 Mid-segment: Adding {total_prompts} prompts across {len(multi_frame_prompts)} frames")
|
||||
|
||||
success_count = 0
|
||||
total_count = 0
|
||||
|
||||
for frame_idx, prompts in multi_frame_prompts.items():
|
||||
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} prompts")
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
obj_id = prompt['obj_id']
|
||||
bbox = prompt['bbox']
|
||||
confidence = prompt.get('confidence', 'unknown')
|
||||
total_count += 1
|
||||
for frame_idx, prompts_data in multi_frame_prompts.items():
|
||||
# Check if this is segmentation mode (masks) or detection mode (bbox prompts)
|
||||
if isinstance(prompts_data, dict) and 'masks' in prompts_data:
|
||||
# Segmentation mode: add masks directly
|
||||
masks_dict = prompts_data['masks']
|
||||
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")
|
||||
|
||||
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
|
||||
for obj_id, mask in masks_dict.items():
|
||||
total_count += 1
|
||||
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
|
||||
|
||||
try:
|
||||
self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
|
||||
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Detection mode: add bounding box prompts (existing logic)
|
||||
prompts = prompts_data
|
||||
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")
|
||||
|
||||
try:
|
||||
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx, # Key: specify the exact frame index
|
||||
obj_id=obj_id,
|
||||
box=bbox.astype(np.float32),
|
||||
)
|
||||
for i, prompt in enumerate(prompts):
|
||||
obj_id = prompt['obj_id']
|
||||
bbox = prompt['bbox']
|
||||
confidence = prompt.get('confidence', 'unknown')
|
||||
total_count += 1
|
||||
|
||||
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
|
||||
success_count += 1
|
||||
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
|
||||
continue
|
||||
try:
|
||||
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx, # Key: specify the exact frame index
|
||||
obj_id=obj_id,
|
||||
box=bbox.astype(np.float32),
|
||||
)
|
||||
|
||||
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
|
||||
continue
|
||||
|
||||
if success_count > 0:
|
||||
logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
|
||||
return True
|
||||
else:
|
||||
logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
|
||||
return False
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user