diff --git a/vr180_streaming/sam2_streaming_simple.py b/vr180_streaming/sam2_streaming_simple.py
index 203832f..ae081e4 100644
--- a/vr180_streaming/sam2_streaming_simple.py
+++ b/vr180_streaming/sam2_streaming_simple.py
@@ -205,10 +205,13 @@ class SAM2StreamingProcessor:
             if out_frame_idx == latest_frame_idx:
                 # Combine all object masks
                 if len(out_mask_logits) > 0:
-                    combined_mask = np.zeros_like(out_mask_logits[0], dtype=bool)
+                    combined_mask = None
                     for mask_logit in out_mask_logits:
                         mask = (mask_logit > 0.0).cpu().numpy()
-                        combined_mask = combined_mask | mask
+                        if combined_mask is None:
+                            combined_mask = mask.astype(bool)
+                        else:
+                            combined_mask = combined_mask | mask.astype(bool)
 
                     return (combined_mask * 255).astype(np.uint8)
 
@@ -250,4 +253,82 @@ class SAM2StreamingProcessor:
             torch.cuda.synchronize()
         gc.collect()
 
-        print("🧹 Simple SAM2 streaming processor cleaned up")
\ No newline at end of file
+        print("🧹 Simple SAM2 streaming processor cleaned up")
+
+    def apply_mask_to_frame(self,
+                            frame: np.ndarray,
+                            mask: np.ndarray,
+                            output_format: str = "alpha",
+                            background_color: tuple = (0, 255, 0)) -> np.ndarray:
+        """
+        Apply mask to frame with specified output format (matches chunked implementation)
+
+        Args:
+            frame: Input frame (BGR)
+            mask: Binary mask (0-255 or boolean)
+            output_format: "alpha" or "greenscreen"
+            background_color: RGB background color for greenscreen mode
+
+        Returns:
+            Processed frame
+        """
+        if mask is None:
+            return frame
+
+        # Ensure mask is 2D (handle 3D masks properly)
+        if mask.ndim == 3:
+            mask = mask.squeeze()
+
+        # Resize mask to match frame if needed (use INTER_NEAREST for binary masks)
+        if mask.shape[:2] != frame.shape[:2]:
+            import cv2
+            # Convert to uint8 for resizing, then back to bool
+            if mask.dtype == bool:
+                mask_uint8 = mask.astype(np.uint8) * 255
+            else:
+                mask_uint8 = mask.astype(np.uint8)
+
+            mask_resized = cv2.resize(mask_uint8,
+                                      (frame.shape[1], frame.shape[0]),
+                                      interpolation=cv2.INTER_NEAREST)
+            mask = mask_resized.astype(bool) if mask.dtype == bool else mask_resized
+
+        if output_format == "alpha":
+            # Create RGBA output (matches chunked implementation)
+            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
+            output[:, :, :3] = frame
+            if mask.dtype == bool:
+                output[:, :, 3] = mask.astype(np.uint8) * 255
+            else:
+                output[:, :, 3] = mask.astype(np.uint8)
+            return output
+
+        elif output_format == "greenscreen":
+            # Create RGB output with background (matches chunked implementation)
+            output = np.full_like(frame, background_color, dtype=np.uint8)
+            if mask.dtype == bool:
+                output[mask] = frame[mask]
+            else:
+                mask_bool = mask.astype(bool)
+                output[mask_bool] = frame[mask_bool]
+            return output
+
+        else:
+            raise ValueError(f"Unsupported output format: {output_format}. Use 'alpha' or 'greenscreen'")
+
+    def get_memory_usage(self) -> Dict[str, float]:
+        """
+        Get current memory usage statistics
+
+        Returns:
+            Dictionary with memory usage info
+        """
+        stats = {}
+
+        if torch.cuda.is_available():
+            # GPU memory stats
+            stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
+            stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
+            stats['cuda_max_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3)
+
+        return stats
\ No newline at end of file