not working
This commit is contained in:
@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
|
||||
class SAM2Processor:
|
||||
"""Handles SAM2-based video segmentation for human tracking."""
|
||||
|
||||
def __init__(self, checkpoint_path: str, config_path: str):
|
||||
def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
|
||||
"""
|
||||
Initialize SAM2 processor.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to SAM2 checkpoint
|
||||
config_path: Path to SAM2 config file
|
||||
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
|
||||
"""
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.config_path = config_path
|
||||
self.vos_optimized = vos_optimized
|
||||
self.predictor = None
|
||||
self._initialize_predictor()
|
||||
|
||||
@@ -62,12 +64,35 @@ class SAM2Processor:
|
||||
|
||||
logger.info(f"Using SAM2 config: {config_name}")
|
||||
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
config_name, # Use just the config name, not full path
|
||||
self.checkpoint_path,
|
||||
device=device,
|
||||
overrides=dict(conf=0.95)
|
||||
)
|
||||
# Use VOS optimization if enabled and supported
|
||||
if self.vos_optimized:
|
||||
try:
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
config_name, # Use just the config name, not full path
|
||||
self.checkpoint_path,
|
||||
device=device,
|
||||
vos_optimized=True # New optimization for major speedup
|
||||
)
|
||||
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to use optimized VOS predictor: {e}")
|
||||
logger.info("Falling back to standard SAM2 predictor")
|
||||
# Fallback to standard predictor
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
config_name,
|
||||
self.checkpoint_path,
|
||||
device=device,
|
||||
overrides=dict(conf=0.95)
|
||||
)
|
||||
else:
|
||||
# Use standard predictor
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
config_name,
|
||||
self.checkpoint_path,
|
||||
device=device,
|
||||
overrides=dict(conf=0.95)
|
||||
)
|
||||
logger.info("Using standard SAM2 predictor")
|
||||
|
||||
# Enable optimizations for CUDA
|
||||
if device.type == "cuda":
|
||||
@@ -248,18 +273,30 @@ class SAM2Processor:
|
||||
Dictionary mapping frame indices to object masks
|
||||
"""
|
||||
video_segments = {}
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
logger.info("Starting SAM2 mask propagation...")
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
|
||||
video_segments[out_frame_idx] = {
|
||||
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
||||
for i, out_obj_id in enumerate(out_obj_ids)
|
||||
}
|
||||
frame_count += 1
|
||||
|
||||
logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
|
||||
# Log progress every 50 frames
|
||||
if frame_count % 50 == 0:
|
||||
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
|
||||
|
||||
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during mask propagation: {e}")
|
||||
logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
|
||||
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
|
||||
if frame_count == 0:
|
||||
logger.error("No frames were processed - propagation failed completely")
|
||||
else:
|
||||
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
|
||||
|
||||
return video_segments
|
||||
|
||||
|
||||
10
main.py
10
main.py
@@ -277,7 +277,8 @@ def main():
|
||||
logger.info("Step 3: Initializing SAM2 processor")
|
||||
sam2_processor = SAM2Processor(
|
||||
checkpoint_path=config.get_sam2_checkpoint(),
|
||||
config_path=config.get_sam2_config()
|
||||
config_path=config.get_sam2_config(),
|
||||
vos_optimized=config.get('models.sam2_vos_optimized', False)
|
||||
)
|
||||
|
||||
# Initialize mask processor
|
||||
@@ -595,6 +596,13 @@ def main():
|
||||
logger.error(f"SAM2 processing failed for segment {segment_idx}")
|
||||
continue
|
||||
|
||||
# Check if SAM2 produced adequate results
|
||||
if len(video_segments) == 0:
|
||||
logger.error(f"SAM2 produced no frames for segment {segment_idx}")
|
||||
continue
|
||||
elif len(video_segments) < 10: # Expected many frames for a 5-second segment
|
||||
logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
|
||||
|
||||
# Debug what SAM2 produced
|
||||
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
|
||||
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
|
||||
|
||||
Reference in New Issue
Block a user