not working

This commit is contained in:
2025-07-27 14:26:20 -07:00
parent 97f12c79a4
commit 02ad4d87d2
2 changed files with 55 additions and 10 deletions

View File

@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
class SAM2Processor: class SAM2Processor:
"""Handles SAM2-based video segmentation for human tracking.""" """Handles SAM2-based video segmentation for human tracking."""
def __init__(self, checkpoint_path: str, config_path: str): def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
""" """
Initialize SAM2 processor. Initialize SAM2 processor.
Args: Args:
checkpoint_path: Path to SAM2 checkpoint checkpoint_path: Path to SAM2 checkpoint
config_path: Path to SAM2 config file config_path: Path to SAM2 config file
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
""" """
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
self.config_path = config_path self.config_path = config_path
self.vos_optimized = vos_optimized
self.predictor = None self.predictor = None
self._initialize_predictor() self._initialize_predictor()
@@ -62,12 +64,35 @@ class SAM2Processor:
logger.info(f"Using SAM2 config: {config_name}") logger.info(f"Using SAM2 config: {config_name}")
# Use VOS optimization if enabled and supported
if self.vos_optimized:
try:
self.predictor = build_sam2_video_predictor( self.predictor = build_sam2_video_predictor(
config_name, # Use just the config name, not full path config_name, # Use just the config name, not full path
self.checkpoint_path, self.checkpoint_path,
device=device, device=device,
vos_optimized=True # New optimization for major speedup
)
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
except Exception as e:
logger.warning(f"Failed to use optimized VOS predictor: {e}")
logger.info("Falling back to standard SAM2 predictor")
# Fallback to standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95) overrides=dict(conf=0.95)
) )
else:
# Use standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
logger.info("Using standard SAM2 predictor")
# Enable optimizations for CUDA # Enable optimizations for CUDA
if device.type == "cuda": if device.type == "cuda":
@@ -248,18 +273,30 @@ class SAM2Processor:
Dictionary mapping frame indices to object masks Dictionary mapping frame indices to object masks
""" """
video_segments = {} video_segments = {}
frame_count = 0
try: try:
logger.info("Starting SAM2 mask propagation...")
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state): for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = { video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids) for i, out_obj_id in enumerate(out_obj_ids)
} }
frame_count += 1
logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects") # Log progress every 50 frames
if frame_count % 50 == 0:
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
except Exception as e: except Exception as e:
logger.error(f"Error during mask propagation: {e}") logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
if frame_count == 0:
logger.error("No frames were processed - propagation failed completely")
else:
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
return video_segments return video_segments

10
main.py
View File

@@ -277,7 +277,8 @@ def main():
logger.info("Step 3: Initializing SAM2 processor") logger.info("Step 3: Initializing SAM2 processor")
sam2_processor = SAM2Processor( sam2_processor = SAM2Processor(
checkpoint_path=config.get_sam2_checkpoint(), checkpoint_path=config.get_sam2_checkpoint(),
config_path=config.get_sam2_config() config_path=config.get_sam2_config(),
vos_optimized=config.get('models.sam2_vos_optimized', False)
) )
# Initialize mask processor # Initialize mask processor
@@ -595,6 +596,13 @@ def main():
logger.error(f"SAM2 processing failed for segment {segment_idx}") logger.error(f"SAM2 processing failed for segment {segment_idx}")
continue continue
# Check if SAM2 produced adequate results
if len(video_segments) == 0:
logger.error(f"SAM2 produced no frames for segment {segment_idx}")
continue
elif len(video_segments) < 10: # Expected many frames for a 5-second segment
logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
# Debug what SAM2 produced # Debug what SAM2 produced
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}") logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames") logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")