cupy and mask
This commit is contained in:
@@ -205,10 +205,13 @@ class SAM2StreamingProcessor:
|
|||||||
if out_frame_idx == latest_frame_idx:
|
if out_frame_idx == latest_frame_idx:
|
||||||
# Combine all object masks
|
# Combine all object masks
|
||||||
if len(out_mask_logits) > 0:
|
if len(out_mask_logits) > 0:
|
||||||
combined_mask = np.zeros_like(out_mask_logits[0], dtype=bool)
|
combined_mask = None
|
||||||
for mask_logit in out_mask_logits:
|
for mask_logit in out_mask_logits:
|
||||||
mask = (mask_logit > 0.0).cpu().numpy()
|
mask = (mask_logit > 0.0).cpu().numpy()
|
||||||
combined_mask = combined_mask | mask
|
if combined_mask is None:
|
||||||
|
combined_mask = mask.astype(bool)
|
||||||
|
else:
|
||||||
|
combined_mask = combined_mask | mask.astype(bool)
|
||||||
|
|
||||||
return (combined_mask * 255).astype(np.uint8)
|
return (combined_mask * 255).astype(np.uint8)
|
||||||
|
|
||||||
@@ -251,3 +254,81 @@ class SAM2StreamingProcessor:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print("🧹 Simple SAM2 streaming processor cleaned up")
|
print("🧹 Simple SAM2 streaming processor cleaned up")
|
||||||
|
|
||||||
|
def apply_mask_to_frame(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
mask: np.ndarray,
|
||||||
|
output_format: str = "alpha",
|
||||||
|
background_color: tuple = (0, 255, 0)) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Apply mask to frame with specified output format (matches chunked implementation)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (BGR)
|
||||||
|
mask: Binary mask (0-255 or boolean)
|
||||||
|
output_format: "alpha" or "greenscreen"
|
||||||
|
background_color: RGB background color for greenscreen mode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed frame
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
# Ensure mask is 2D (handle 3D masks properly)
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.squeeze()
|
||||||
|
|
||||||
|
# Resize mask to match frame if needed (use INTER_NEAREST for binary masks)
|
||||||
|
if mask.shape[:2] != frame.shape[:2]:
|
||||||
|
import cv2
|
||||||
|
# Convert to uint8 for resizing, then back to bool
|
||||||
|
if mask.dtype == bool:
|
||||||
|
mask_uint8 = mask.astype(np.uint8) * 255
|
||||||
|
else:
|
||||||
|
mask_uint8 = mask.astype(np.uint8)
|
||||||
|
|
||||||
|
mask_resized = cv2.resize(mask_uint8,
|
||||||
|
(frame.shape[1], frame.shape[0]),
|
||||||
|
interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask_resized.astype(bool) if mask.dtype == bool else mask_resized
|
||||||
|
|
||||||
|
if output_format == "alpha":
|
||||||
|
# Create RGBA output (matches chunked implementation)
|
||||||
|
output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
|
||||||
|
output[:, :, :3] = frame
|
||||||
|
if mask.dtype == bool:
|
||||||
|
output[:, :, 3] = mask.astype(np.uint8) * 255
|
||||||
|
else:
|
||||||
|
output[:, :, 3] = mask.astype(np.uint8)
|
||||||
|
return output
|
||||||
|
|
||||||
|
elif output_format == "greenscreen":
|
||||||
|
# Create RGB output with background (matches chunked implementation)
|
||||||
|
output = np.full_like(frame, background_color, dtype=np.uint8)
|
||||||
|
if mask.dtype == bool:
|
||||||
|
output[mask] = frame[mask]
|
||||||
|
else:
|
||||||
|
mask_bool = mask.astype(bool)
|
||||||
|
output[mask_bool] = frame[mask_bool]
|
||||||
|
return output
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported output format: {output_format}. Use 'alpha' or 'greenscreen'")
|
||||||
|
|
||||||
|
def get_memory_usage(self) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Get current memory usage statistics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with memory usage info
|
||||||
|
"""
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
# GPU memory stats
|
||||||
|
stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
|
||||||
|
stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
|
||||||
|
stats['cuda_max_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
|
|
||||||
|
return stats
|
||||||
Reference in New Issue
Block a user