not working
This commit is contained in:
@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
|
|||||||
class SAM2Processor:
|
class SAM2Processor:
|
||||||
"""Handles SAM2-based video segmentation for human tracking."""
|
"""Handles SAM2-based video segmentation for human tracking."""
|
||||||
|
|
||||||
def __init__(self, checkpoint_path: str, config_path: str):
|
def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
|
||||||
"""
|
"""
|
||||||
Initialize SAM2 processor.
|
Initialize SAM2 processor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_path: Path to SAM2 checkpoint
|
checkpoint_path: Path to SAM2 checkpoint
|
||||||
config_path: Path to SAM2 config file
|
config_path: Path to SAM2 config file
|
||||||
|
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
|
||||||
"""
|
"""
|
||||||
self.checkpoint_path = checkpoint_path
|
self.checkpoint_path = checkpoint_path
|
||||||
self.config_path = config_path
|
self.config_path = config_path
|
||||||
|
self.vos_optimized = vos_optimized
|
||||||
self.predictor = None
|
self.predictor = None
|
||||||
self._initialize_predictor()
|
self._initialize_predictor()
|
||||||
|
|
||||||
@@ -62,12 +64,35 @@ class SAM2Processor:
|
|||||||
|
|
||||||
logger.info(f"Using SAM2 config: {config_name}")
|
logger.info(f"Using SAM2 config: {config_name}")
|
||||||
|
|
||||||
self.predictor = build_sam2_video_predictor(
|
# Use VOS optimization if enabled and supported
|
||||||
config_name, # Use just the config name, not full path
|
if self.vos_optimized:
|
||||||
self.checkpoint_path,
|
try:
|
||||||
device=device,
|
self.predictor = build_sam2_video_predictor(
|
||||||
overrides=dict(conf=0.95)
|
config_name, # Use just the config name, not full path
|
||||||
)
|
self.checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
vos_optimized=True # New optimization for major speedup
|
||||||
|
)
|
||||||
|
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to use optimized VOS predictor: {e}")
|
||||||
|
logger.info("Falling back to standard SAM2 predictor")
|
||||||
|
# Fallback to standard predictor
|
||||||
|
self.predictor = build_sam2_video_predictor(
|
||||||
|
config_name,
|
||||||
|
self.checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
overrides=dict(conf=0.95)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use standard predictor
|
||||||
|
self.predictor = build_sam2_video_predictor(
|
||||||
|
config_name,
|
||||||
|
self.checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
overrides=dict(conf=0.95)
|
||||||
|
)
|
||||||
|
logger.info("Using standard SAM2 predictor")
|
||||||
|
|
||||||
# Enable optimizations for CUDA
|
# Enable optimizations for CUDA
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
@@ -248,18 +273,30 @@ class SAM2Processor:
|
|||||||
Dictionary mapping frame indices to object masks
|
Dictionary mapping frame indices to object masks
|
||||||
"""
|
"""
|
||||||
video_segments = {}
|
video_segments = {}
|
||||||
|
frame_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info("Starting SAM2 mask propagation...")
|
||||||
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
|
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
|
||||||
video_segments[out_frame_idx] = {
|
video_segments[out_frame_idx] = {
|
||||||
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
||||||
for i, out_obj_id in enumerate(out_obj_ids)
|
for i, out_obj_id in enumerate(out_obj_ids)
|
||||||
}
|
}
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
|
# Log progress every 50 frames
|
||||||
|
if frame_count % 50 == 0:
|
||||||
|
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
|
||||||
|
|
||||||
|
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during mask propagation: {e}")
|
logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
|
||||||
|
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
|
||||||
|
if frame_count == 0:
|
||||||
|
logger.error("No frames were processed - propagation failed completely")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
|
||||||
|
|
||||||
return video_segments
|
return video_segments
|
||||||
|
|
||||||
|
|||||||
10
main.py
10
main.py
@@ -277,7 +277,8 @@ def main():
|
|||||||
logger.info("Step 3: Initializing SAM2 processor")
|
logger.info("Step 3: Initializing SAM2 processor")
|
||||||
sam2_processor = SAM2Processor(
|
sam2_processor = SAM2Processor(
|
||||||
checkpoint_path=config.get_sam2_checkpoint(),
|
checkpoint_path=config.get_sam2_checkpoint(),
|
||||||
config_path=config.get_sam2_config()
|
config_path=config.get_sam2_config(),
|
||||||
|
vos_optimized=config.get('models.sam2_vos_optimized', False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize mask processor
|
# Initialize mask processor
|
||||||
@@ -595,6 +596,13 @@ def main():
|
|||||||
logger.error(f"SAM2 processing failed for segment {segment_idx}")
|
logger.error(f"SAM2 processing failed for segment {segment_idx}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if SAM2 produced adequate results
|
||||||
|
if len(video_segments) == 0:
|
||||||
|
logger.error(f"SAM2 produced no frames for segment {segment_idx}")
|
||||||
|
continue
|
||||||
|
elif len(video_segments) < 10: # Expected many frames for a 5-second segment
|
||||||
|
logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
|
||||||
|
|
||||||
# Debug what SAM2 produced
|
# Debug what SAM2 produced
|
||||||
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
|
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
|
||||||
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
|
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
|
||||||
|
|||||||
Reference in New Issue
Block a user