pixano_inference.impls.sam2.video

SAM2 video tracking model.

Sam2VideoModel(config)

Bases: TrackingModel

Native Ray Serve model for SAM2 video mask generation.

model_params contract:

  • path (str, required): HuggingFace model ID or local checkpoint path.
  • torch_dtype (str, default "bfloat16"): Torch dtype for autocast.
  • compile (bool, default True): Whether to torch.compile the model.
  • vos_optimized (bool, default True): Use VOS-optimised predictor.
  • propagate (bool, default True): Whether to propagate masks across the full video after adding prompts.

Any remaining keys are forwarded to the SAM2 predictor builder.
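
For reference, the contract above might be satisfied by a payload like the following. Only the key names are taken from the contract; the values are illustrative, and the checkpoint ID is an assumption.

# Hypothetical model_params payload; values are illustrative.
model_params = {
    "path": "facebook/sam2-hiera-large",  # HF model ID, or a local checkpoint path
    "torch_dtype": "bfloat16",
    "compile": True,
    "vos_optimized": True,
    "propagate": True,
    # any extra keys are forwarded to the SAM2 predictor builder
}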

Parameters:

| Name   | Type                  | Description                     | Default  |
|--------|-----------------------|---------------------------------|----------|
| config | ModelDeploymentConfig | Model deployment configuration. | required |
Source code in pixano_inference/impls/sam2/video.py
def __init__(self, config: ModelDeploymentConfig) -> None:
    """Initialize the model.

    Args:
        config: Model deployment configuration.
    """
    super().__init__(config)
    self._predictor: Any = None
    self._torch_dtype: Any = None
    self._propagate: bool = True

metadata property

Model metadata including path and dtype.
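
The exact return shape is not documented here; assuming the property surfaces the configured path and dtype, reading it might look like this sketch.

# Assumption: metadata exposes at least the configured path and dtype.
# `model` is a loaded Sam2VideoModel instance.
info = model.metadata
print(info)  # e.g. {"path": "facebook/sam2-hiera-large", "torch_dtype": "bfloat16"}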

load_model()

Load the SAM2 video predictor.

Source code in pixano_inference/impls/sam2/video.py
def load_model(self) -> None:
    """Load the SAM2 video predictor."""
    from pixano_inference.utils.package import assert_sam2_installed

    assert_sam2_installed()

    import torch
    from sam2.build_sam import build_sam2_video_predictor, build_sam2_video_predictor_hf

    params = dict(self._config.model_params)
    path = params.pop("path")
    torch_dtype_str = params.pop("torch_dtype", "bfloat16")
    compile_model = params.pop("compile", True)
    vos_optimized = params.pop("vos_optimized", True)
    self._propagate = params.pop("propagate", True)

    device = resolve_device(self._config)
    self._torch_dtype = resolve_torch_dtype(torch_dtype_str)

    if path is not None and Path(path).exists():
        predictor = build_sam2_video_predictor(
            ckpt_path=path, mode="eval", device=device, vos_optimized=vos_optimized, **params
        )
    else:
        predictor = build_sam2_video_predictor_hf(
            model_id=path, mode="eval", device=device, vos_optimized=vos_optimized, **params
        )

    if compile_model:
        predictor = torch.compile(predictor)

    self._predictor = predictor
    logger.info("Sam2VideoModel '%s' loaded on %s (dtype=%s)", self.model_name, device, torch_dtype_str)
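
The path value decides which SAM2 builder is used: an existing filesystem path goes through build_sam2_video_predictor (which also expects a Hydra config name, presumably forwarded among the remaining model_params keys), while anything else is treated as a HuggingFace model ID for build_sam2_video_predictor_hf. A minimal standalone sketch of the two code paths, assuming the upstream builders keep their usual signatures; the config file name is an assumption that must match the checkpoint.

from pathlib import Path

from sam2.build_sam import build_sam2_video_predictor, build_sam2_video_predictor_hf

path = "facebook/sam2-hiera-large"  # or e.g. "/models/sam2.1_hiera_large.pt"
if Path(path).exists():
    # Local checkpoint: the builder also needs the matching Hydra config.
    predictor = build_sam2_video_predictor(
        config_file="configs/sam2.1/sam2.1_hiera_l.yaml",  # assumption
        ckpt_path=path,
        device="cuda",
        mode="eval",
    )
else:
    # Anything that is not an existing path is treated as a HuggingFace model ID.
    predictor = build_sam2_video_predictor_hf(model_id=path, device="cuda", mode="eval")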

predict(input)

Run SAM2 video mask generation.

Parameters:

| Name  | Type          | Description                                         | Default  |
|-------|---------------|------------------------------------------------------|----------|
| input | TrackingInput | Tracking input with video, prompts, and object IDs. | required |

Returns:

| Type           | Description                                                  |
|----------------|---------------------------------------------------------------|
| TrackingOutput | Tracking output with objects_ids, frame_indexes, and masks.  |

Source code in pixano_inference/impls/sam2/video.py
def predict(self, input: TrackingInput) -> TrackingOutput:
    """Run SAM2 video mask generation.

    Args:
        input: Tracking input with video, prompts, and object IDs.

    Returns:
        Tracking output with objects_ids, frame_indexes, and masks.
    """
    import torch

    from pixano_inference.schemas.rle import CompressedRLE

    objects_ids = input.objects_ids
    request_propagate = self._propagate if input.propagate is None else input.propagate
    frame_indexes = input.frame_indexes

    if len(objects_ids) != len(frame_indexes):
        raise ValueError("objects_ids and frame_indexes must have the same length.")

    validate_prompts(input.points, input.labels, input.boxes)

    num_objects = len(objects_ids)
    if input.points is not None:
        input_points = [np.array(p, dtype=np.int32) for p in input.points]
    else:
        input_points = [None] * num_objects  # type: ignore[list-item]
    if input.labels is not None:
        input_labels = [np.array(lbl, dtype=np.int32) for lbl in input.labels]
    else:
        input_labels = [None] * num_objects  # type: ignore[list-item]
    if input.boxes is not None:
        input_boxes = [np.array(b, dtype=np.int32) for b in input.boxes]
    else:
        input_boxes = [None] * num_objects  # type: ignore[list-item]

    video_segments: dict[int, dict[int, np.ndarray]] = {}

    with torch.inference_mode():
        torch.compiler.cudagraph_mark_step_begin()
        with torch.autocast(self._predictor.device.type, dtype=self._torch_dtype):
            inference_state = self._init_video_state(input.video)

            if input.keyframes is not None:
                for obj_id, keyframe in zip(objects_ids, input.keyframes, strict=False):
                    out_frame_idx, out_obj_ids, out_mask_logits = self._apply_keyframe_prompt(
                        inference_state,
                        obj_id,
                        keyframe,
                    )
                    if not request_propagate:
                        self._merge_video_segments(video_segments, out_frame_idx, out_obj_ids, out_mask_logits)
            else:
                for obj_id, frame_idx, obj_points, obj_labels, obj_box in zip(
                    objects_ids,
                    frame_indexes,
                    input_points,
                    input_labels,
                    input_boxes,
                    strict=False,
                ):
                    _, out_obj_ids, out_mask_logits = self._apply_legacy_prompt(
                        inference_state,
                        obj_id,
                        frame_idx,
                        obj_points,
                        obj_labels,
                        obj_box,
                    )
                    if not request_propagate:
                        self._merge_video_segments(video_segments, frame_idx, out_obj_ids, out_mask_logits)

            if request_propagate:
                video_segments = {}
                propagation_kwargs = self._build_propagation_kwargs(input)
                for out_frame_idx, out_obj_ids, out_mask_logits in self._predictor.propagate_in_video(
                    inference_state,
                    **propagation_kwargs,
                ):
                    self._merge_video_segments(video_segments, out_frame_idx, out_obj_ids, out_mask_logits)

    out_objects_ids: list[int] = []
    out_frame_indexes: list[int] = []
    out_masks: list[CompressedRLE] = []

    for frame_index, object_masks in video_segments.items():
        for object_id, mask in object_masks.items():
            out_objects_ids.append(object_id)
            out_frame_indexes.append(frame_index)
            out_masks.append(CompressedRLE.from_mask(mask[0].astype(np.uint8)))

    return TrackingOutput(
        objects_ids=out_objects_ids,
        frame_indexes=out_frame_indexes,
        masks=out_masks,
    )
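
As a rough illustration of the prompt layout the method expects (one prompt set per object, parallel to objects_ids and frame_indexes): the field names follow the code above, but the import path, the video value, and the point/label conventions are assumptions about the schema.

# Hedged sketch; `model` is a loaded Sam2VideoModel instance.
from pixano_inference.schemas import TrackingInput  # assumption: actual import path may differ

tracking_input = TrackingInput(
    video="/data/videos/clip_0001",       # assumption: frame directory or video path
    objects_ids=[1, 2],                   # one entry per prompted object
    frame_indexes=[0, 0],                 # frame on which each object is prompted
    points=[[[210, 350]], [[480, 120]]],  # per-object click coordinates
    labels=[[1], [1]],                    # assumption: 1 = positive click, 0 = negative
    boxes=None,
    keyframes=None,
    propagate=True,                       # overrides the model-level default
)

output = model.predict(tracking_input)
# output.objects_ids, output.frame_indexes, and output.masks are parallel lists;
# each mask is a CompressedRLE for one object on one frame.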

unload()

Free resources.

Source code in pixano_inference/impls/sam2/video.py
def unload(self) -> None:
    """Free resources."""
    if self._predictor is not None:
        del self._predictor
        self._predictor = None
    gc.collect()
    try:
        import torch

        torch.cuda.empty_cache()
    except Exception:
        pass