ntxent_mined_loss
Contrastive Loss Functions
NTXentNegativeMinedLoss
Bases: Module
NTXentNegativeMinedLoss: NTXentLoss with explicitly mined negatives
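The standard NT-Xent objective, restricted to explicitly supplied negatives, is shown below as an illustration. The symbols are assumptions for this sketch. $z_i$ is an anchor embedding, $z_i^{+}$ is its positive, and $z_{i,1}^{-},\dots,z_{i,K}^{-}$ are its $K$ mined negatives. $\operatorname{sim}(u,v) = u^\top v / (\lVert u\rVert \lVert v\rVert)$ is cosine similarity, and $\tau$ is `temperature`. This page does not say how the class selects or batches its negatives.

$$
\mathcal{L}_i = -\log \frac{\exp\left(\operatorname{sim}(z_i, z_i^{+})/\tau\right)}{\exp\left(\operatorname{sim}(z_i, z_i^{+})/\tau\right) + \sum_{k=1}^{K} \exp\left(\operatorname{sim}(z_i, z_{i,k}^{-})/\tau\right)}
$$

The batch loss is then the mean of $\mathcal{L}_i$ over all anchors.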
Parameters:
Name | Type | Description | Default |
---|---|---|---|
temperature | float | The temperature parameter for the loss calculation. Default is 0.1. | 0.1 |
gather_distributed | bool | Whether to gather hidden representations from other processes in a distributed setting. Default is False. | False |
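The `gather_distributed` flag collects representations from all processes before the loss is computed. This typically gives each rank a larger pool of comparisons. Below is a minimal sketch of such a gather using `torch.distributed`. The helper name `gather_if_distributed` is hypothetical and is not part of fmcib. The sketch also assumes that every rank holds a tensor of the same shape, which `all_gather` requires.

```python
import torch
import torch.distributed as dist


def gather_if_distributed(x: torch.Tensor, gather_distributed: bool) -> torch.Tensor:
    """Hypothetical helper (not fmcib's code): concatenate `x` from all ranks along dim 0."""
    # Fall back to the local tensor when gathering is disabled or no process group exists
    if not gather_distributed or not (dist.is_available() and dist.is_initialized()):
        return x
    # all_gather requires an equally shaped tensor on every rank (assumption of this sketch)
    buffers = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, x)  # gathered tensors carry no autograd history
    # Put the local tensor back so gradients still flow through this rank's shard
    buffers[dist.get_rank()] = x
    return torch.cat(buffers, dim=0)
```

Tensors returned by `all_gather` carry no gradients. For that reason, the sketch writes the local shard back so that this rank's encoder still receives gradients. Gradients belonging to other ranks' shards are not propagated here.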
Raises:
Type | Description |
---|---|
ValueError | If the absolute value of temperature is less than 1e-8. |
Source code in fmcib/ssl/losses/ntxent_mined_loss.py
__init__(temperature=0.1, gather_distributed=False)
Initialize the NTXentNegativeMinedLoss object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
temperature | float | The temperature parameter for the loss function. Defaults to 0.1. | 0.1 |
gather_distributed | bool | Whether to gather hidden representations from other processes in a distributed setting. Defaults to False. | False |
Raises:
Type | Description |
---|---|
ValueError | If the absolute value of the temperature is less than 1e-8. |
Attributes:
Name | Type | Description |
---|---|---|
temperature | float | The temperature parameter for the loss function. |
gather_distributed | bool | Whether to use distributed gathering or not. |
cross_entropy | CrossEntropyLoss | The cross entropy loss function. |
eps | float | A small value to avoid division by zero. |
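The skeleton below is a hypothetical reconstruction of the documented attributes and the temperature check. `NTXentMinedLossSketch` is an illustrative name and is not the fmcib class. Setting `eps` to `1e-8` is an assumption based on the threshold given in the Raises table.

```python
import torch
from torch import nn


class NTXentMinedLossSketch(nn.Module):
    """Hypothetical skeleton mirroring the documented constructor; not the fmcib source."""

    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
        super().__init__()
        self.temperature = temperature
        self.gather_distributed = gather_distributed
        # Cross entropy over similarity logits (default reduction="mean")
        self.cross_entropy = nn.CrossEntropyLoss()
        # Assumed value: the 1e-8 threshold quoted in the Raises table
        self.eps = 1e-8
        # Dividing logits by a near-zero temperature would blow up, so reject it early
        if abs(self.temperature) < self.eps:
            raise ValueError(f"Illegal temperature: abs({self.temperature}) < {self.eps}")
```

The check uses `abs()`, so a negative temperature with a large enough magnitude would pass validation in this sketch.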
forward(out)
Forward pass through the negative-mining contrastive cross-entropy loss.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
out | Dict | Dictionary with | required |
Returns:
Type | Description |
---|---|
torch.Tensor | Contrastive cross-entropy loss value. |
Raises:
Type | Description |
---|---|
AssertionError | If |
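The keys expected in `out` are not listed above, so the sketch below takes anchors, positives, and mined negatives as separate tensors. The function name and tensor layout are hypothetical. The sketch shows one common way to compute an NT-Xent loss with explicit negatives. The positive is placed in column 0 of the logits, so the cross-entropy target is 0 for every row.

```python
import torch
import torch.nn.functional as F


def ntxent_mined_loss_sketch(
    anchors: torch.Tensor,    # (N, D) hypothetical anchor embeddings
    positives: torch.Tensor,  # (N, D) one positive per anchor
    negatives: torch.Tensor,  # (N, K, D) K explicitly mined negatives per anchor
    temperature: float = 0.1,
) -> torch.Tensor:
    # L2-normalise so that dot products become cosine similarities
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)

    # (N, 1): similarity of each anchor with its positive
    pos_logits = (anchors * positives).sum(dim=-1, keepdim=True)
    # (N, K): similarity of each anchor with its own mined negatives
    neg_logits = torch.einsum("nd,nkd->nk", anchors, negatives)

    # Positive sits in column 0, so the target class is 0 for every row
    logits = torch.cat([pos_logits, neg_logits], dim=1) / temperature
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    # Mean over the batch of -log softmax(logits)[:, 0]
    return F.cross_entropy(logits, targets)
```

With the positive in a fixed column, a standard cross-entropy call computes the $-\log$ softmax of the positive logit directly. No custom normalisation over the negatives is needed.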