swav
SwaV
Bases: Module
Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
backbone | Module | CNN backbone for feature extraction. | required |
num_ftrs | int | Number of input features for the projection head. | required |
out_dim | int | Output dimension of the projection head. | required |
n_prototypes | int | Number of prototypes to compute. | required |
n_queues | int | Number of memory banks (queues). Should equal the number of high-resolution inputs. | required |
queue_length | int | Length of each memory bank. | 0 |
start_queue_at_epoch | int | Epoch at which SwaV starts using the queued features. | 0 |
n_steps_frozen_prototypes | int | Number of training steps during which the prototypes are kept fixed. | 0 |
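The `n_steps_frozen_prototypes` parameter keeps the prototypes fixed during the first training steps. The sketch below illustrates that rule with plain Python lists and a hand-written gradient step. `update_prototypes` and its arguments are hypothetical, and the sketch assumes the prototypes are frozen exactly while `step < n_steps_frozen_prototypes`. In a PyTorch model the same effect would more likely come from disabling gradients on the prototype layer than from a custom update function.

```python
from typing import List

Vector = List[float]


def update_prototypes(
    prototypes: List[Vector],
    grads: List[Vector],
    step: int,
    lr: float,
    n_steps_frozen_prototypes: int = 0,
) -> List[Vector]:
    """Hypothetical SGD step on prototype vectors that respects a freeze period."""
    if step < n_steps_frozen_prototypes:
        # Still inside the freeze window: return an unchanged copy.
        return [list(p) for p in prototypes]
    # Freeze window is over: apply a plain gradient-descent update.
    return [
        [w - lr * g for w, g in zip(p, gp)]
        for p, gp in zip(prototypes, grads)
    ]


protos = [[1.0, 2.0]]
grads = [[2.0, 2.0]]
print(update_prototypes(protos, grads, step=0, lr=0.5, n_steps_frozen_prototypes=2))  # [[1.0, 2.0]]
print(update_prototypes(protos, grads, step=2, lr=0.5, n_steps_frozen_prototypes=2))  # [[0.0, 1.0]]
```

With the default `n_steps_frozen_prototypes=0`, the condition is never true, so the prototypes are trainable from the first step.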
Source code in fmcib/ssl/modules/swav.py
__init__(backbone, num_ftrs, out_dim, n_prototypes, n_queues, queue_length=0, start_queue_at_epoch=0, n_steps_frozen_prototypes=0)
Initialize a SwaV model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
backbone | Module | The backbone model. | required |
num_ftrs | int | The number of input features for the projection head. | required |
out_dim | int | The output dimension of the projection head. | required |
n_prototypes | int | The number of prototypes. | required |
n_queues | int | The number of queues (memory banks). | required |
queue_length | int | The length of each queue. | 0 |
start_queue_at_epoch | int | The epoch at which to start using the queue. | 0 |
n_steps_frozen_prototypes | int | The number of steps during which the prototypes are frozen. | 0 |
Returns:
Type | Description |
---|---|
None |
Attributes:
Name | Type | Description |
---|---|---|
backbone | Module | The backbone model. |
projection_head | SwaVProjectionHead | The projection head. |
prototypes | SwaVPrototypes | The prototypes. |
queues | ModuleList | The queues. If n_queues > 0, this is initialized with one MemoryBankModule per queue. |
queue_length | int | The length of the queue. |
num_features_queued | int | The number of features queued. |
start_queue_at_epoch | int | The epoch at which to start using the queue. |
_get_queue_prototypes(high_resolution_features, epoch=None)
Compute the queue prototypes for the given high-resolution features.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
high_resolution_features | List[Tensor] | List of high-resolution feature tensors. | required |
epoch | int | Current epoch number. Required if | None |
Returns:
Type | Description |
---|---|
List[Tensor] or None | List of queue prototype tensors if queues are configured and the current epoch has reached start_queue_at_epoch; otherwise None. |
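The queue logic can be sketched without tensors. In this hypothetical version, each queue is a plain Python list that keeps at most `queue_length` feature vectors (oldest dropped first), there is one queue per high-resolution input, and a prototype score is the dot product of a feature vector with a prototype vector. The sketch also assumes three details not stated above: the features returned for a queue are those stored before the current batch is enqueued, the queues keep filling before `start_queue_at_epoch` even though None is returned, and `epoch` is only needed when `start_queue_at_epoch > 0`.

```python
from typing import List, Optional

Vector = List[float]


def get_queue_prototypes(
    high_resolution_features: List[List[Vector]],
    queues: List[List[Vector]],
    queue_length: int,
    prototypes: List[Vector],
    epoch: Optional[int] = None,
    start_queue_at_epoch: int = 0,
) -> Optional[List[List[Vector]]]:
    """Hypothetical, tensor-free sketch of queue-prototype computation."""
    if not queues:
        # No memory banks configured.
        return None
    if len(high_resolution_features) != len(queues):
        raise ValueError("expected one queue per high-resolution input")

    queued = []
    for queue, batch in zip(queues, high_resolution_features):
        queued.append(list(queue))  # features stored before this batch
        queue.extend(batch)  # enqueue the current batch
        overflow = max(0, len(queue) - queue_length)
        del queue[:overflow]  # drop the oldest entries (FIFO)

    if start_queue_at_epoch > 0:
        if epoch is None:
            raise ValueError("epoch is needed when start_queue_at_epoch > 0")
        if epoch < start_queue_at_epoch:
            # Queues are still filling up; skip queue prototypes for now.
            return None

    # Score every queued feature vector against every prototype.
    return [
        [[sum(a * b for a, b in zip(x, c)) for c in prototypes] for x in feats]
        for feats in queued
    ]


protos = [[1.0, 0.0], [0.0, 1.0]]
queues = [[], []]  # one queue per high-resolution view
batch = [[[1.0, 0.0]], [[0.0, 1.0]]]
print(get_queue_prototypes(batch, queues, 4, protos))  # [[], []]
print(get_queue_prototypes(batch, queues, 4, protos))  # [[[1.0, 0.0]], [[0.0, 1.0]]]
```

The first call returns empty score lists because nothing had been queued yet; the second call scores the features stored by the first.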
_subforward(input)
Partial forward pass that computes the features for an input image tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | Tensor | Input image tensor. | required |
Returns:
Type | Description |
---|---|
Tensor | L2-normalized feature tensor. |
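L2 normalization divides each feature vector z by its Euclidean norm, z / max(||z||_2, eps), so every row has unit length and its dot product with a unit-length prototype is a cosine similarity. The sketch below shows only this final step on plain Python lists; `l2_normalize` is a hypothetical helper. The eps guard mirrors the behaviour of `torch.nn.functional.normalize(x, p=2, dim=1)`, which is one common way to do this in PyTorch, but whether `_subforward` uses that exact call is an assumption.

```python
import math
from typing import List


def l2_normalize(rows: List[List[float]], eps: float = 1e-12) -> List[List[float]]:
    """Scale each feature vector (one row per sample) to unit Euclidean length."""
    normalized = []
    for row in rows:
        # Guard against division by zero for all-zero vectors.
        norm = max(math.sqrt(sum(v * v for v in row)), eps)
        normalized.append([v / norm for v in row])
    return normalized


print(l2_normalize([[3.0, 4.0], [0.0, 0.0]]))  # [[0.6, 0.8], [0.0, 0.0]]
```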
forward(input, epoch=None, step=None)
Performs the forward pass for the SwAV model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | Tuple[List[Tensor], List[Tensor]] | A tuple consisting of a list of high-resolution input images and a list of low-resolution input images. | required |
epoch | int | Current training epoch. Required if | None |
step | int | Current training step. Required if | None |
Returns:
Type | Description |
---|---|
Tuple[List[Tensor], List[Tensor], List[Tensor]] | A tuple containing lists of high-resolution prototypes, low-resolution prototypes, and queue prototypes. |
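The shape of the tuple returned by `forward` can be illustrated with a small, tensor-free sketch of the data flow. It assumes, as in the SwAV formulation, that the prototype vectors are L2-normalized and that a view's prototype scores are dot products between its normalized features and each prototype. `swav_forward_sketch`, `normalize`, and `toy_encoder` are hypothetical; `encode` stands in for the backbone, projection head, and normalization, and the queue part is left out, so the third element is always None here.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Batch = List[Vector]


def normalize(v: Vector) -> Vector:
    """Return v scaled to unit Euclidean length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def swav_forward_sketch(
    high_resolution: List[Batch],
    low_resolution: List[Batch],
    prototypes: List[Vector],
    encode: Callable[[Batch], Batch],
) -> Tuple[List[Batch], List[Batch], Optional[List[Batch]]]:
    """Hypothetical sketch of the forward data flow, without memory-bank queues."""
    protos = [normalize(c) for c in prototypes]

    def scores(view: Batch) -> Batch:
        # One row per sample, one column per prototype.
        return [[sum(a * b for a, b in zip(z, c)) for c in protos] for z in encode(view)]

    high = [scores(view) for view in high_resolution]
    low = [scores(view) for view in low_resolution]
    return high, low, None  # queue prototypes omitted in this sketch


def toy_encoder(batch: Batch) -> Batch:
    # Stand-in encoder: treats each input as a feature vector and normalizes it.
    return [normalize(x) for x in batch]


high, low, queue = swav_forward_sketch(
    high_resolution=[[[1.0, 0.0]], [[0.0, 2.0]]],  # two high-resolution views
    low_resolution=[[[3.0, 4.0]]],  # one low-resolution view
    prototypes=[[2.0, 0.0], [0.0, 3.0]],
    encode=toy_encoder,
)
print(high)  # [[[1.0, 0.0]], [[0.0, 1.0]]]
print(low)  # [[[0.6, 0.8]]]
print(queue)  # None
```

Each high- and low-resolution view yields its own score matrix, which is why the first two elements of the tuple are lists with one entry per input view.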