Networks

eva.vision.models.networks.ABMIL

Bases: Module

ABMIL network for multiple instance learning classification tasks.

Takes an array of patch-level embeddings per slide as input. This implementation supports batched inputs of shape (batch_size, n_instances, input_size). For slides with fewer than n_instances patches, you can pad the input with pad_value so that the padded entries are masked during the forward pass (see the usage sketch below).

The original implementation from [1] was used as a reference: https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py

Notes
  • use_bias: The paper does not use a bias term in its formalism, but the authors' published example code inadvertently does.
  • To prevent near-equal dot-product similarities caused by concentration of measure in high-dimensional input embeddings (>128), we added the option to project the input embeddings to a lower dimensionality.

[1] Maximilian Ilse, Jakub M. Tomczak, Max Welling, "Attention-based Deep Multiple Instance Learning", 2018 https://arxiv.org/abs/1802.04712
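
A minimal usage sketch, assuming ABMIL is exported from eva.vision.models.networks as the page title suggests; the embedding dimension, class count, and instance counts below are illustrative, and the exact handling of padded entries is done internally via pad_value:

import torch

from eva.vision.models.networks import ABMIL

# 384-dim patch embeddings, 2 output classes, no input projection.
model = ABMIL(input_size=384, output_size=2, projected_input_size=None)

# Batch of 4 slides padded to 1000 instances each; padded entries use the
# default pad_value of float("-inf") and are masked inside the forward pass.
embeddings = torch.randn(4, 1000, 384)
embeddings[:, 800:, :] = float("-inf")  # e.g. slides with only 800 real patches

logits = model(embeddings)
print(logits.shape)  # expected: torch.Size([4, 2])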

Parameters:

  input_size (int, required): input embedding dimension.
  output_size (int, required): number of classes.
  projected_input_size (int | None, required): size of the projected input. If None, no projection is performed.
  hidden_size_attention (int, default 128): hidden dimension of the attention network.
  hidden_sizes_mlp (tuple, default (128, 64)): dimensions of the hidden layers in the final MLP.
  use_bias (bool, default True): whether to use bias in the attention network.
  dropout_input_embeddings (float, default 0.0): dropout rate for the input embeddings.
  dropout_attention (float, default 0.0): dropout rate for the attention network and classifier.
  dropout_mlp (float, default 0.0): dropout rate for the final MLP network.
  pad_value (int | float | None, default float('-inf')): value indicating padding in the input tensor. If specified, entries with this value in the input tensor are masked. If None, no masking is applied.
Source code in src/eva/vision/models/networks/abmil.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    projected_input_size: int | None,
    hidden_size_attention: int = 128,
    hidden_sizes_mlp: tuple = (128, 64),
    use_bias: bool = True,
    dropout_input_embeddings: float = 0.0,
    dropout_attention: float = 0.0,
    dropout_mlp: float = 0.0,
    pad_value: int | float | None = float("-inf"),
) -> None:
    """Initializes the ABMIL network.

    Args:
        input_size: input embedding dimension
        output_size: number of classes
        projected_input_size: size of the projected input. if `None`, no projection is
            performed.
        hidden_size_attention: hidden dimension in attention network
        hidden_sizes_mlp: dimensions for hidden layers in last mlp
        use_bias: whether to use bias in the attention network
        dropout_input_embeddings: dropout rate for the input embeddings
        dropout_attention: dropout rate for the attention network and classifier
        dropout_mlp: dropout rate for the final MLP network
        pad_value: Value indicating padding in the input tensor. If specified, entries with
            this value in the input tensor will be masked. If set to `None`, no masking is applied.
    """
    super().__init__()

    self._pad_value = pad_value

    if projected_input_size:
        self.projector = nn.Sequential(
            nn.Linear(input_size, projected_input_size, bias=True),
            nn.Dropout(p=dropout_input_embeddings),
        )
        input_size = projected_input_size
    else:
        self.projector = nn.Dropout(p=dropout_input_embeddings)

    self.gated_attention = GatedAttention(
        input_dim=input_size,
        hidden_dim=hidden_size_attention,
        dropout=dropout_attention,
        n_classes=1,
        use_bias=use_bias,
    )

    self.classifier = MLP(
        input_size=input_size,
        output_size=output_size,
        hidden_layer_sizes=hidden_sizes_mlp,
        dropout=dropout_mlp,
        hidden_activation_fn=nn.ReLU,
    )
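
A small sketch of what the projected_input_size option changes, based on the constructor shown above; the embedding dimension and class count are illustrative:

from eva.vision.models.networks import ABMIL

# Without projection, the projector stage is only a Dropout layer.
plain = ABMIL(input_size=768, output_size=4, projected_input_size=None)
print(plain.projector)  # Dropout(p=0.0, inplace=False)

# With projection, embeddings are first mapped 768 -> 128 before the attention network.
projected = ABMIL(input_size=768, output_size=4, projected_input_size=128)
print(projected.projector)  # Sequential: Linear(768 -> 128) followed by Dropout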

forward

Forward pass.

Parameters:

  input_tensor (Tensor, required): tensor with expected shape of (batch_size, n_instances, input_size).
Source code in src/eva/vision/models/networks/abmil.py
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Args:
        input_tensor: Tensor with expected shape of (batch_size, n_instances, input_size).
    """
    input_tensor, mask = self._mask_values(input_tensor, self._pad_value)

    # (batch_size, n_instances, input_size) -> (batch_size, n_instances, projected_input_size)
    input_tensor = self.projector(input_tensor)

    attention_logits = self.gated_attention(input_tensor)  # (batch_size, n_instances, 1)
    if mask is not None:
        # fill masked values with -inf, which will yield 0s after softmax
        attention_logits = attention_logits.masked_fill(mask, float("-inf"))

    attention_weights = nn.functional.softmax(attention_logits, dim=1)
    # (batch_size, n_instances, 1)

    attention_result = torch.matmul(torch.transpose(attention_weights, 1, 2), input_tensor)
    # (batch_size, 1, (projected_)input_size)

    attention_result = torch.squeeze(attention_result, 1)  # (batch_size, (projected_)input_size)

    return self.classifier(attention_result)  # (batch_size, output_size)
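
For reference, a self-contained sketch of the attention-pooling step performed above, using plain tensor ops and a random stand-in for the attention logits; all shapes are illustrative and the stand-in is not the library's GatedAttention module:

import torch

batch_size, n_instances, input_size = 2, 5, 16
embeddings = torch.randn(batch_size, n_instances, input_size)

# Stand-in for the gated attention network: one logit per instance.
attention_logits = torch.randn(batch_size, n_instances, 1)

# Softmax over the instance dimension yields per-instance weights that sum to 1.
attention_weights = torch.softmax(attention_logits, dim=1)  # (batch_size, n_instances, 1)

# Weighted sum of the instance embeddings -> one pooled embedding per slide.
pooled = torch.matmul(attention_weights.transpose(1, 2), embeddings)  # (batch_size, 1, input_size)
pooled = pooled.squeeze(1)  # (batch_size, input_size)
print(pooled.shape)  # torch.Size([2, 16])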