Skip to content

Wrappers

Reference information for the model Wrappers API.

eva.vision.models.wrappers.TimmModel

Bases: BaseModel

Model wrapper for timm models.

Note that only models that implement a forward_intermediates method are currently supported.

Parameters:

Name Type Description Default
model_name str

Name of model to instantiate.

required
pretrained bool

If set to True, load pretrained ImageNet-1k weights. When checkpoint_path is set, pretrained loading is always enabled regardless of this flag.

True
checkpoint_path str

Path of the checkpoint to load.

''
out_indices int | Tuple[int, ...] | None

If an int n, return the last n blocks; if None, return all blocks; if a sequence, return the blocks at the matching indices.

None
model_kwargs Dict[str, Any] | None

Extra model arguments.

None
tensor_transforms Callable | None

The transforms to apply to the output tensor produced by the model.

None
Source code in src/eva/vision/models/wrappers/from_timm.py
def __init__(
    self,
    model_name: str,
    pretrained: bool = True,
    checkpoint_path: str = "",
    out_indices: int | Tuple[int, ...] | None = None,
    model_kwargs: Dict[str, Any] | None = None,
    tensor_transforms: Callable | None = None,
) -> None:
    """Configures the timm wrapper and builds the underlying model.

    The constructor records its arguments on the instance and then
    creates the network right away by calling ``self.load_model()``.

    Args:
        model_name: Name of the timm model to instantiate.
        pretrained: Whether to load pretrained ImageNet-1k weights. When
            ``checkpoint_path`` is non-empty, ``load_model`` turns
            pretrained loading on regardless of this flag.
        checkpoint_path: Path of the checkpoint to load. The default empty
            string leaves the ``pretrained`` flag in control.
        out_indices: Which blocks to return: the last n blocks for an
            ``int``, or the blocks at the given positions for a sequence.
            ``None`` disables timm's ``features_only`` mode in
            ``load_model``. NOTE(review): the public docs describe ``None``
            as "all blocks"; confirm what the model returns in that case.
        model_kwargs: Extra keyword arguments forwarded to
            ``timm.create_model``. ``None`` (or an empty mapping) is
            replaced by a fresh empty dict.
        tensor_transforms: Transforms to apply to the output tensor
            produced by the model; handed to the base class.
    """
    # Initialise the base class before any wrapper-specific state is set.
    super().__init__(tensor_transforms=tensor_transforms)

    # Configuration consumed by ``load_model``.
    self._model_name, self._pretrained = model_name, pretrained
    self._checkpoint_path = checkpoint_path
    self._out_indices = out_indices
    self._model_kwargs = {} if not model_kwargs else model_kwargs

    # Build the network eagerly so the wrapper is usable immediately.
    self.load_model()

load_model

Builds and loads the timm model as a feature extractor.

Source code in src/eva/vision/models/wrappers/from_timm.py
@override
def load_model(self) -> None:
    """Creates the timm network and stores it on ``self._model``.

    Passes the configuration captured in ``__init__`` to
    ``timm.create_model``. It also forwards every entry of
    ``self._model_kwargs`` as an extra keyword argument.
    """
    # A non-empty checkpoint path always enables pretrained loading;
    # otherwise the user's ``pretrained`` flag decides.
    load_pretrained = bool(self._checkpoint_path) or self._pretrained
    # ``_pretrained_cfg`` is defined elsewhere in the class (not visible
    # here); presumably it points timm at ``checkpoint_path``. Confirm.
    pretrained_cfg = self._pretrained_cfg
    # Switch timm into feature-extraction mode only when explicit output
    # indices were requested.
    features_only = self._out_indices is not None

    self._model = timm.create_model(
        model_name=self._model_name,
        pretrained=load_pretrained,
        pretrained_cfg=pretrained_cfg,
        out_indices=self._out_indices,
        features_only=features_only,
        **self._model_kwargs,
    )
    # NOTE(review): this rewrites ``__name__`` on the TimmModel *class*, not
    # the instance. Every instance therefore reports the most recently loaded
    # model name. Presumably this is intentional, for display/logging.
    TimmModel.__name__ = self._model_name