not working

This commit is contained in:
2025-07-27 14:26:20 -07:00
parent 97f12c79a4
commit 02ad4d87d2
2 changed files with 55 additions and 10 deletions

View File

@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
class SAM2Processor:
"""Handles SAM2-based video segmentation for human tracking."""
def __init__(self, checkpoint_path: str, config_path: str):
def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
"""
Initialize SAM2 processor.
Args:
checkpoint_path: Path to SAM2 checkpoint
config_path: Path to SAM2 config file
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
"""
self.checkpoint_path = checkpoint_path
self.config_path = config_path
self.vos_optimized = vos_optimized
self.predictor = None
self._initialize_predictor()
@@ -62,12 +64,35 @@ class SAM2Processor:
logger.info(f"Using SAM2 config: {config_name}")
self.predictor = build_sam2_video_predictor(
config_name, # Use just the config name, not full path
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
# Use VOS optimization if enabled and supported
if self.vos_optimized:
try:
self.predictor = build_sam2_video_predictor(
config_name, # Use just the config name, not full path
self.checkpoint_path,
device=device,
vos_optimized=True # New optimization for major speedup
)
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
except Exception as e:
logger.warning(f"Failed to use optimized VOS predictor: {e}")
logger.info("Falling back to standard SAM2 predictor")
# Fallback to standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
else:
# Use standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
logger.info("Using standard SAM2 predictor")
# Enable optimizations for CUDA
if device.type == "cuda":
@@ -248,18 +273,30 @@ class SAM2Processor:
Dictionary mapping frame indices to object masks
"""
video_segments = {}
frame_count = 0
try:
logger.info("Starting SAM2 mask propagation...")
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
frame_count += 1
# Log progress every 50 frames
if frame_count % 50 == 0:
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
except Exception as e:
logger.error(f"Error during mask propagation: {e}")
logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
if frame_count == 0:
logger.error("No frames were processed - propagation failed completely")
else:
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
return video_segments