swav
SwaV
Bases: Module
Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `backbone` | `Module` | CNN backbone for feature extraction. | required |
| `num_ftrs` | `int` | Number of input features for the projection head. | required |
| `out_dim` | `int` | Output dimension of the projection head. | required |
| `n_prototypes` | `int` | Number of learnable prototypes. | required |
| `n_queues` | `int` | Number of memory banks (queues). Should equal the number of high-resolution inputs. | required |
| `queue_length` | `int` | Length of each memory bank. | `0` |
| `start_queue_at_epoch` | `int` | Epoch at which SwaV starts using the queued features. | `0` |
| `n_steps_frozen_prototypes` | `int` | Number of training steps during which the prototypes are kept fixed. | `0` |
Source code in fmcib/ssl/modules/swav.py
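The `input` argument of `forward` bundles several augmented views of the same batch, and `n_queues` is tied to how many of those views are high-resolution. The sketch below builds such a tuple from random tensors and derives `n_queues` from it. The helper `make_multicrop_input`, the crop counts (2 high-resolution, 4 low-resolution) and the 2D image sizes are illustrative assumptions, not values taken from fmcib.

```python
import torch


def make_multicrop_input(batch_size=4, n_high=2, n_low=4, high_size=64, low_size=32):
    """Hypothetical helper: build (high_res_views, low_res_views) from random
    2D image batches. Counts and sizes are illustrative assumptions."""
    high_res_views = [torch.randn(batch_size, 3, high_size, high_size) for _ in range(n_high)]
    low_res_views = [torch.randn(batch_size, 3, low_size, low_size) for _ in range(n_low)]
    return high_res_views, low_res_views


high_res, low_res = make_multicrop_input()
model_input = (high_res, low_res)  # the tuple layout documented for SwaV.forward

# One memory bank per high-resolution view, per the n_queues description.
n_queues = len(high_res)
```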
__init__(backbone, num_ftrs, out_dim, n_prototypes, n_queues, queue_length=0, start_queue_at_epoch=0, n_steps_frozen_prototypes=0)
Initialize a SwaV model.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `backbone` | `Module` | The backbone model. | required |
| `num_ftrs` | `int` | The number of input features. | required |
| `out_dim` | `int` | The dimension of the output. | required |
| `n_prototypes` | `int` | The number of prototypes. | required |
| `n_queues` | `int` | The number of queues. | required |
| `queue_length` | `int` | The length of the queue. | `0` |
| `start_queue_at_epoch` | `int` | The epoch at which to start using the queue. | `0` |
| `n_steps_frozen_prototypes` | `int` | The number of steps during which the prototypes are frozen. | `0` |
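`n_steps_frozen_prototypes` keeps the prototypes fixed at the start of training. One way to express this is sketched below, assuming the prototypes are the weight rows of a bias-free linear layer and that "frozen" means their gradient is disabled. `PrototypesSketch` is a hypothetical stand-in, not fmcib's `SwaVPrototypes`, and the rule that `step` must be passed once freezing is enabled is also an assumption.

```python
from typing import Optional

import torch
from torch import nn


class PrototypesSketch(nn.Module):
    """Hypothetical stand-in for SwaVPrototypes: a bias-free linear layer whose
    weight rows act as prototypes and stay frozen for the first
    `n_steps_frozen_prototypes` training steps."""

    def __init__(self, out_dim: int, n_prototypes: int, n_steps_frozen_prototypes: int = 0):
        super().__init__()
        self.layer = nn.Linear(out_dim, n_prototypes, bias=False)
        self.n_steps_frozen_prototypes = n_steps_frozen_prototypes

    def forward(self, x: torch.Tensor, step: Optional[int] = None) -> torch.Tensor:
        if self.n_steps_frozen_prototypes > 0:
            if step is None:
                # Assumption: the step counter is mandatory once freezing is enabled.
                raise ValueError("step is required when n_steps_frozen_prototypes > 0")
            # Gradients reach the prototypes only after the frozen phase ends.
            self.layer.weight.requires_grad_(step >= self.n_steps_frozen_prototypes)
        # Dot-product scores between each feature vector and each prototype.
        return self.layer(x)


prototypes = PrototypesSketch(out_dim=16, n_prototypes=10, n_steps_frozen_prototypes=3)
features = torch.randn(4, 16)
scores = prototypes(features, step=0)
frozen_at_step_0 = not prototypes.layer.weight.requires_grad
prototypes(features, step=3)
trainable_at_step_3 = prototypes.layer.weight.requires_grad
```

Toggling `requires_grad` leaves the weights untouched during the frozen phase because optimizers skip parameters whose gradient is `None`.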
Returns:
| Type | Description |
| --- | --- |
| `None` | |
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| `backbone` | `Module` | The backbone model. |
| `projection_head` | `SwaVProjectionHead` | The projection head. |
| `prototypes` | `SwaVPrototypes` | The prototypes. |
| `queues` | `ModuleList` | The queues. If `n_queues > 0`, this is initialized with `MemoryBankModule` instances. |
| `queue_length` | `int` | The length of the queue. |
| `num_features_queued` | `int` | The number of features queued. |
| `start_queue_at_epoch` | `int` | The epoch at which to start using the queue. |
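Each entry in `queues` stores recent high-resolution features, so that prototype assignments can be computed over more samples than a single batch provides (the purpose the SwAV paper gives for its feature queue). The FIFO buffer below is a simplified, hypothetical stand-in for a memory bank: its `(queue_length, dim)` layout, returning the contents from before the update, and the `num_filled` counter are assumptions for illustration, not the behavior of `MemoryBankModule`.

```python
import torch


class FeatureQueueSketch:
    """Hypothetical FIFO memory bank holding the `queue_length` most recent
    feature vectors. Not fmcib's or lightly's MemoryBankModule."""

    def __init__(self, queue_length: int, dim: int):
        self.queue_length = queue_length
        self.bank = torch.zeros(queue_length, dim)
        self.num_filled = 0  # rows holding real features rather than initial zeros

    def enqueue(self, features: torch.Tensor) -> torch.Tensor:
        """Push a (batch, dim) tensor and return the bank as it was before the push."""
        previous = self.bank.clone()
        # Newest features go to the front; the oldest rows fall off the end.
        self.bank = torch.cat([features.detach(), self.bank])[: self.queue_length]
        self.num_filled = min(self.num_filled + features.shape[0], self.queue_length)
        return previous


queue = FeatureQueueSketch(queue_length=5, dim=3)
before_first = queue.enqueue(torch.ones(2, 3))
before_second = queue.enqueue(2 * torch.ones(2, 3))
```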
_get_queue_prototypes(high_resolution_features, epoch=None)
Compute the queue prototypes for the given high-resolution features.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `high_resolution_features` | `List[Tensor]` | List of high-resolution feature tensors. | required |
| `epoch` | `int` | Current epoch number. Required if | `None` |
Returns:
| Type | Description |
| --- | --- |
| `List[Tensor]` or `None` | List of queue prototype tensors if the conditions are met, otherwise `None`. |
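Whether `_get_queue_prototypes` returns scores or `None` depends on the epoch. The function below sketches one plausible gate, assuming the queue is ignored until `epoch` reaches `start_queue_at_epoch` and that `epoch` must be supplied whenever `start_queue_at_epoch > 0`. `queue_prototypes_sketch` and its `score_fn` argument are hypothetical names; the exact conditions are not spelled out in this reference.

```python
from typing import Callable, List, Optional

import torch


def queue_prototypes_sketch(
    queue_features: List[torch.Tensor],
    score_fn: Callable[[torch.Tensor], torch.Tensor],
    start_queue_at_epoch: int,
    epoch: Optional[int] = None,
) -> Optional[List[torch.Tensor]]:
    """Hypothetical gate: score the queued features only once the queue is active."""
    if start_queue_at_epoch > 0 and epoch is None:
        raise ValueError("epoch is required when start_queue_at_epoch > 0")
    if epoch is not None and epoch < start_queue_at_epoch:
        # Queue not active yet: the caller gets None instead of prototype scores.
        return None
    return [score_fn(features) for features in queue_features]


prototypes = torch.nn.Linear(8, 5, bias=False)
queued = [torch.randn(6, 8), torch.randn(6, 8)]  # one tensor per queue
early = queue_prototypes_sketch(queued, prototypes, start_queue_at_epoch=2, epoch=1)
active = queue_prototypes_sketch(queued, prototypes, start_queue_at_epoch=2, epoch=2)
```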
_subforward(input)
Compute the L2-normalized features for an input image tensor.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `Tensor` | Input image tensor. | required |
Returns:
| Type | Description |
| --- | --- |
| `Tensor` | L2-normalized feature tensor. |
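Because the returned features have unit length, their dot products with normalized prototypes are cosine similarities. A minimal sketch of such a feature pass follows, assuming the backbone output is flattened and passed through the projection head before normalization. `subforward_sketch` and the toy backbone and head are placeholders, not fmcib's modules.

```python
import torch
import torch.nn.functional as F
from torch import nn


def subforward_sketch(backbone: nn.Module, projection_head: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical feature pass: backbone -> flatten -> projection head -> L2 norm."""
    features = backbone(x).flatten(start_dim=1)  # (batch, num_ftrs)
    features = projection_head(features)  # (batch, out_dim)
    return F.normalize(features, p=2, dim=1)  # each row rescaled to unit length


torch.manual_seed(0)
# Placeholder modules: num_ftrs = 8, out_dim = 4.
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.AdaptiveAvgPool2d(1))
projection_head = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
z = subforward_sketch(backbone, projection_head, torch.randn(5, 3, 16, 16))
```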
forward(input, epoch=None, step=None)
Performs the forward pass for the SwAV model.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `Tuple[List[Tensor], List[Tensor]]` | A tuple consisting of a list of high-resolution input images and a list of low-resolution input images. | required |
| `epoch` | `int` | Current training epoch. Required if | `None` |
| `step` | `int` | Current training step. Required if | `None` |
Returns:
| Type | Description |
| --- | --- |
| `Tuple[List[Tensor], List[Tensor], List[Tensor]]` | A tuple containing lists of high-resolution prototypes, low-resolution prototypes, and queue prototypes. |
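Putting the pieces together, `forward` embeds every view, scores it against the prototypes and returns three lists. The toy module below mirrors that data flow under several assumptions: the prototypes are a bias-free linear layer whose rows are re-normalized before each pass (as in the SwAV paper), the queue is omitted so the third element is `None`, and `ToySwaV` with its tiny backbone is hypothetical rather than fmcib code.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToySwaV(nn.Module):
    """Hypothetical, queue-free sketch of the SwaV forward data flow."""

    def __init__(self, out_dim: int = 4, n_prototypes: int = 6):
        super().__init__()
        # Global pooling lets the same backbone accept both crop resolutions.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.projection_head = nn.Linear(8, out_dim)
        self.prototypes = nn.Linear(out_dim, n_prototypes, bias=False)

    def _subforward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.projection_head(self.backbone(x).flatten(start_dim=1))
        return F.normalize(features, dim=1)

    def forward(self, input):
        # Keep each prototype on the unit sphere before scoring.
        with torch.no_grad():
            self.prototypes.weight.copy_(F.normalize(self.prototypes.weight, dim=1))
        high_resolution, low_resolution = input
        high_prototypes = [self.prototypes(self._subforward(x)) for x in high_resolution]
        low_prototypes = [self.prototypes(self._subforward(x)) for x in low_resolution]
        queue_prototypes = None  # no memory bank in this toy version
        return high_prototypes, low_prototypes, queue_prototypes


torch.manual_seed(0)
model = ToySwaV()
views = (
    [torch.randn(2, 3, 32, 32) for _ in range(2)],  # high-resolution crops
    [torch.randn(2, 3, 16, 16) for _ in range(3)],  # low-resolution crops
)
high, low, queue = model(views)
```

Since both features and prototypes are unit vectors here, every score lies in [-1, 1].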