swav_loss
SwaVLoss
Bases: SwaVLoss
A class representing a custom SwaV loss function.
Attributes:
Name | Type | Description |
---|---|---|
temperature | float | The temperature parameter for the loss calculation. Default is 0.1. |
sinkhorn_iterations | int | The number of iterations for the Sinkhorn algorithm. Default is 3. |
sinkhorn_epsilon | float | The epsilon parameter for the Sinkhorn algorithm. Default is 0.05. |
sinkhorn_gather_distributed | bool | Whether to gather distributed results for the Sinkhorn algorithm. Default is False. |
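To make the Sinkhorn attributes concrete, here is a minimal pure-Python sketch of Sinkhorn-Knopp normalization. It is an illustration, not the code in fmcib/ssl/losses/swav_loss.py. The function name `sinkhorn` and the samples-by-prototypes layout of `scores` are assumptions made for this example. `iterations` and `epsilon` play the roles of `sinkhorn_iterations` and `sinkhorn_epsilon`.

```python
import math


def sinkhorn(scores, iterations=3, epsilon=0.05):
    """Turn a (samples x prototypes) score matrix into soft assignments.

    Illustrative sketch only; hypothetical helper, not the fmcib API.
    """
    n_samples = len(scores)
    n_protos = len(scores[0])

    # Subtract the global maximum before exponentiating for numerical stability.
    top = max(max(row) for row in scores)
    q = [[math.exp((s - top) / epsilon) for s in row] for row in scores]
    total = sum(sum(row) for row in q)
    q = [[v / total for v in row] for row in q]

    for _ in range(iterations):
        # Column step: every prototype receives total mass 1 / n_protos.
        col_sums = [sum(q[i][j] for i in range(n_samples)) for j in range(n_protos)]
        q = [
            [q[i][j] / (col_sums[j] * n_protos) for j in range(n_protos)]
            for i in range(n_samples)
        ]
        # Row step: every sample carries total mass 1 / n_samples.
        row_sums = [sum(row) for row in q]
        q = [[v / (row_sums[i] * n_samples) for v in q[i]] for i in range(n_samples)]

    # Rescale so that each row is a probability distribution over prototypes.
    return [[v * n_samples for v in row] for row in q]
```

A smaller `epsilon` sharpens the assignments, while more `iterations` balance the prototype usage more evenly across the batch. Scores that differ by far more than `epsilon` can underflow to zero in this simple version, so it is best tried on similarity-like values in a narrow range.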
Source code in fmcib/ssl/losses/swav_loss.py
__init__(temperature=0.1, sinkhorn_iterations=3, sinkhorn_epsilon=0.05, sinkhorn_gather_distributed=False)
Initialize the loss with the specified parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
temperature | float | The temperature parameter. Default is 0.1. | 0.1 |
sinkhorn_iterations | int | The number of Sinkhorn iterations. Default is 3. | 3 |
sinkhorn_epsilon | float | The epsilon parameter for the Sinkhorn algorithm. Default is 0.05. | 0.05 |
sinkhorn_gather_distributed | bool | Whether to use distributed computation for the Sinkhorn algorithm. Default is False. | False |
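As a usage sketch, the constructor can be called with the documented defaults spelled out. The import path below is an assumption inferred from the source file location (fmcib/ssl/losses/swav_loss.py). The `try`/`except` guard is only there so the snippet still runs where fmcib is not installed.

```python
# Defaults copied from the parameter table above.
swav_kwargs = {
    "temperature": 0.1,
    "sinkhorn_iterations": 3,
    "sinkhorn_epsilon": 0.05,
    "sinkhorn_gather_distributed": False,
}

try:
    # Assumption: the module path mirrors fmcib/ssl/losses/swav_loss.py.
    from fmcib.ssl.losses.swav_loss import SwaVLoss

    criterion = SwaVLoss(**swav_kwargs)
except ImportError:
    # fmcib (or one of its dependencies) is not available in this environment.
    criterion = None
```

Passing the values as keyword arguments keeps the call readable and makes it easy to override a single setting, for example `sinkhorn_gather_distributed=True` in a multi-GPU run.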
forward(pred)
Compute the loss from the predicted outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred | tuple | A tuple containing the predicted outputs for high resolution, low resolution, and queue. | required |
Returns:
Type | Description |
---|---|
| | The output of the forward pass. |
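The single `pred` argument suggests a thin adapter that unpacks one tuple and hands its parts to an underlying loss. The sketch below shows that pattern with stand-ins. `TupleLossAdapter` and `toy_loss` are hypothetical names, and the tuple order (high resolution, low resolution, queue) is assumed from the parameter description rather than confirmed against the source.

```python
class TupleLossAdapter:
    """Hypothetical stand-in: accepts one tuple and forwards its parts."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def forward(self, pred):
        # Assumed order, following the parameter description:
        # (high resolution, low resolution, queue).
        high_resolution, low_resolution, queue = pred
        return self.loss_fn(high_resolution, low_resolution, queue)


def toy_loss(high_resolution, low_resolution, queue=None):
    """Toy placeholder, not SwaV: counts the items it received."""
    queue_len = 0 if queue is None else len(queue)
    return len(high_resolution) + len(low_resolution) + queue_len


adapter = TupleLossAdapter(toy_loss)
result = adapter.forward((["hr_0", "hr_1"], ["lr_0", "lr_1", "lr_2"], None))
```

Bundling the three inputs into one tuple lets a training loop pass the model output straight to the criterion without knowing how many views or queue entries it holds.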