Model Wrappers
This document shows how to use eva's Model Wrapper API (eva.models.wrappers) to load models in different formats and from different sources, such as PyTorch Hub, the HuggingFace Model Hub and ONNX files.
eva model registry
To load models from eva's FM backbone model registry, we provide the ModelFromRegistry wrapper class:
backbone:
  class_path: eva.vision.models.wrappers.ModelFromRegistry
  init_args:
    model_name: universal/vit_small_patch16_224_dino
    model_kwargs:
      out_indices: 1
By setting the out_indices=1 keyword argument, the model will return a feature map tensor, which is needed for segmentation tasks. If you omit this argument, it will return the CLS embedding (for classification tasks).
PyTorch models
The eva framework is built on top of PyTorch Lightning and thus naturally supports loading PyTorch models.
You just need to specify the class path of your model in the backbone section of the .yaml config file.
Note that your ModelClass should subclass torch.nn.Module and implement the forward() method to return an embedding tensor of shape [1, embedding_dim] for classification tasks, or a list of feature maps of shape [1, embedding_dim, patch_dim, patch_dim] for segmentation tasks.
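As an illustration of these requirements, here is a minimal sketch built on a toy architecture. The class TinyEncoder, its return_feature_maps flag and the embedding size are hypothetical and not part of eva. For a single input image, the leading (batch) dimension is 1, which matches the shapes described above.

```python
from typing import List, Union

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical backbone returning an embedding or a list of feature maps."""

    def __init__(self, embedding_dim: int = 64, return_feature_maps: bool = False) -> None:
        super().__init__()
        self.return_feature_maps = return_feature_maps
        # A strided convolution embeds each 16x16 image patch into a vector.
        self.patch_embed = nn.Conv2d(3, embedding_dim, kernel_size=16, stride=16)

    def forward(self, images: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        # images: [batch, 3, height, width]
        feature_map = self.patch_embed(images)  # [batch, embedding_dim, patch_dim, patch_dim]
        if self.return_feature_maps:
            return [feature_map]  # segmentation: list of feature maps
        # Average over the patch grid to get one vector per image.
        return feature_map.mean(dim=(2, 3))  # classification: [batch, embedding_dim]


model = TinyEncoder(embedding_dim=64)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64])
```

With a 224x224 input and 16x16 patches, the segmentation variant returns a single feature map of shape [1, 64, 14, 14].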
Models from functions
The wrapper class eva.models.wrappers.ModelFromFunction allows you to load models from Python functions that return torch model instances (nn.Module). You can use it to load models either from your own custom functions or from public providers such as Torch Hub or timm that expose model load functions.
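A custom load function only has to build and return an nn.Module. In the minimal sketch below, build_mlp_encoder and its parameters are hypothetical. The arguments dictionary stands in for the arguments section of the config, on the assumption that the wrapper passes those entries to the function as keyword arguments.

```python
import torch
from torch import nn


def build_mlp_encoder(in_features: int, embedding_dim: int, depth: int = 2) -> nn.Module:
    """Hypothetical model load function returning a torch model instance."""
    layers = []
    width = in_features
    for _ in range(depth):
        # Each block projects to embedding_dim followed by a non-linearity.
        layers += [nn.Linear(width, embedding_dim), nn.GELU()]
        width = embedding_dim
    return nn.Sequential(*layers)


# Keyword arguments as they might appear under `arguments:` in the config.
arguments = {"in_features": 32, "embedding_dim": 16, "depth": 3}
model = build_mlp_encoder(**arguments)
print(model(torch.randn(4, 32)).shape)  # torch.Size([4, 16])
```

In the config, path would then hold the function's import path, for example a hypothetical my_package.models.build_mlp_encoder.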
torch.hub.load
The following example shows how to load a dino_vits16 model from Torch Hub using the torch.hub.load function:
backbone:
  class_path: eva.models.wrappers.ModelFromFunction
  init_args:
    path: torch.hub.load
    arguments:
      repo_or_dir: facebookresearch/dino:main
      model: dino_vits16
      pretrained: false
    checkpoint_path: path/to/your/checkpoint.torch
Note that if a checkpoint_path is provided, ModelFromFunction will automatically initialize the specified model with the weights from that checkpoint file.
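Conceptually, initializing a model from a checkpoint means loading saved weights into a freshly built instance. The sketch below does this in plain PyTorch. It assumes the checkpoint file holds a state_dict, uses a hypothetical build_model function, and does not describe how ModelFromFunction handles checkpoints internally.

```python
import os
import tempfile

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for a model load function such as torch.hub.load.
    return nn.Linear(8, 4)


# Save the weights of one model instance to a checkpoint file.
trained = build_model()
checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint.torch")
torch.save(trained.state_dict(), checkpoint_path)

# Build a fresh, randomly initialized instance and load the saved weights into it.
model = build_model()
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)

print(torch.equal(model.weight, trained.weight))  # True
```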
timm.create_model
Similarly to the above example, we can easily load models using the common vision library timm:
backbone:
  class_path: eva.models.wrappers.ModelFromFunction
  init_args:
    path: timm.create_model
    arguments:
      model_name: resnet18
      pretrained: true
timm models
While you can load timm models using the ModelFromFunction wrapper class as shown in the example above, we also provide a specific wrapper class:
backbone:
  class_path: eva.vision.models.wrappers.TimmModel
  init_args:
    model_name: vit_tiny_patch16_224
    pretrained: true
    out_indices: 1 # to return the last feature map
    model_kwargs:
      dynamic_img_size: true
HuggingFace models
For loading models from the HuggingFace Hub, eva provides a custom wrapper class HuggingFaceModel, which can be used as follows:
backbone:
  class_path: eva.models.wrappers.HuggingFaceModel
  init_args:
    model_name_or_path: owkin/phikon
    tensor_transforms:
      class_path: eva.models.networks.transforms.ExtractCLSFeatures
In the above example, the forward pass implemented by the owkin/phikon model returns an output tensor containing the hidden states of all input tokens. To extract only the state corresponding to the CLS token (for classification tasks), we can specify a transformation via the tensor_transforms argument, which is applied to the model output. For segmentation tasks, we can use the ExtractPatchFeatures transformation instead to extract patch feature maps.
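To make the two transformations concrete, here is a minimal sketch of the tensor operations involved. It assumes a ViT-style output of shape [batch, 1 + num_patches, hidden_dim], with the CLS token first and the patch tokens forming a square grid. extract_cls and extract_patch_maps are hypothetical helpers, not eva's ExtractCLSFeatures or ExtractPatchFeatures implementations.

```python
import math

import torch


def extract_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    # Keep only the first token (CLS): [batch, hidden_dim]
    return hidden_states[:, 0, :]


def extract_patch_maps(hidden_states: torch.Tensor) -> torch.Tensor:
    # Drop the CLS token and keep the patch tokens: [batch, num_patches, hidden_dim]
    patches = hidden_states[:, 1:, :]
    batch, num_patches, hidden_dim = patches.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError("expected a square grid of patch tokens")
    # Move channels first and fold the token sequence into a grid:
    # [batch, hidden_dim, side, side]
    return patches.permute(0, 2, 1).reshape(batch, hidden_dim, side, side)


# E.g. a 224x224 input with 16x16 patches yields 14 * 14 = 196 patch tokens.
hidden_states = torch.randn(2, 1 + 196, 768)
print(extract_cls(hidden_states).shape)         # torch.Size([2, 768])
print(extract_patch_maps(hidden_states).shape)  # torch.Size([2, 768, 14, 14])
```

The row-major reshape places patch token n at grid position (n // side, n % side), which matches the usual left-to-right, top-to-bottom patch ordering of ViTs.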
ONNX models
.onnx model checkpoints can be loaded using the ONNXModel wrapper class.
Implementing custom model wrappers
You can also implement your own model wrapper classes in case your model format is not supported by the wrapper classes that eva already provides. To do so, you need to subclass eva.models.wrappers.BaseModel and implement the following abstract methods:
load_model: Returns an instantiated model object and loads pre-trained model weights from a checkpoint, if available.
model_forward: Implements the forward pass of the model and returns the output as a torch.Tensor of shape [embedding_dim].
You can use the implementations of the ModelFromFunction, HuggingFaceModel and ONNXModel wrappers as a reference.
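The pattern of these two abstract methods can be sketched as follows, under explicit assumptions. The exact signature of eva.models.wrappers.BaseModel is not shown here, so the code defines a hypothetical stand-in base class (WrapperBase) that has just the two methods described above. It also defines a hypothetical TorchScriptModel wrapper for TorchScript files loaded with torch.jit.load. In a real implementation you would subclass eva.models.wrappers.BaseModel instead and follow its actual method signatures.

```python
import abc
import os
import tempfile

import torch
from torch import nn


class WrapperBase(nn.Module, abc.ABC):
    """Hypothetical stand-in for eva.models.wrappers.BaseModel."""

    @abc.abstractmethod
    def load_model(self) -> nn.Module:
        """Return an instantiated model, with pre-trained weights if available."""

    @abc.abstractmethod
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the model's forward pass and return the output tensor."""

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.model_forward(tensor)


class TorchScriptModel(WrapperBase):
    """Hypothetical wrapper for TorchScript model files."""

    def __init__(self, path: str) -> None:
        super().__init__()
        self._path = path
        self._model = self.load_model()

    def load_model(self) -> nn.Module:
        # A TorchScript archive bundles both the architecture and the weights.
        model = torch.jit.load(self._path, map_location="cpu")
        model.eval()
        return model

    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self._model(tensor)


# Export a toy model to TorchScript, then load it through the wrapper.
path = os.path.join(tempfile.mkdtemp(), "encoder.pt")
torch.jit.script(nn.Linear(8, 4)).save(path)

wrapper = TorchScriptModel(path)
with torch.no_grad():
    print(wrapper(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```

Keeping loading and the forward pass in separate methods means the shared forward() does not need to know anything about the underlying model format.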