
ntxent_loss

NTXentLoss

Bases: lightly.loss.NTXentLoss (imported as lightly_NTXentLoss)

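NTXentLoss: a thin wrapper around lightly's NTXentLoss whose forward accepts both projected views as a single list. No explicit negative mining is performed; the negatives are the other samples in the batch (or in the global batch when gather_distributed=True).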
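For reference, the standard NT-Xent objective (as defined for SimCLR) is written below. A batch of N samples yields 2N embeddings z_1, ..., z_2N, (i, j) is a positive pair (two views of the same sample), sim is cosine similarity and τ is the temperature argument. We assume the parent class follows this definition when no memory bank is used.

```latex
\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]}\, \exp(\mathrm{sim}(z_i, z_k)/\tau)},
\qquad \mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert\, \lVert v \rVert},
\qquad \mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell_{2k-1,\,2k} + \ell_{2k,\,2k-1} \right]
```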

Source code in fmcib/ssl/losses/ntxent_loss.py
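from typing import List

import torch
from lightly.loss import NTXentLoss as lightly_NTXentLoss


class NTXentLoss(lightly_NTXentLoss):
    """
    NTXentLoss:
    Thin wrapper around lightly's NTXentLoss whose forward accepts both views as a single list.
    """

    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
        """
        Initialize the loss.

        Args:
            temperature (float, optional): Temperature that scales the cosine similarities before the softmax. Defaults to 0.1.
            gather_distributed (bool, optional): If True, gather embeddings from all processes so negatives come from the global batch. Defaults to False.
        """
        # Pass by keyword: in lightly's NTXentLoss the second positional parameter is memory_bank_size.
        super().__init__(temperature=temperature, gather_distributed=gather_distributed)

    def forward(self, out: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the NT-Xent contrastive loss for a pair of views.

        Args:
            out (List[torch.Tensor]): Two projected views [out0, out1], each of shape (batch_size, embedding_dim).

        Returns:
            torch.Tensor: Scalar contrastive cross-entropy loss.
        """
        # Unpack [out0, out1] into lightly's forward(out0, out1).
        return super().forward(*out)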
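To make the wrapped computation concrete, here is a dependency-free sketch of the NT-Xent loss for two batches of paired embeddings. It is intended to mirror the standard formulation with no memory bank and gather_distributed=False; the function name nt_xent and the plain-list inputs are illustrative assumptions, not part of fmcib or lightly (both of which operate on torch tensors). Rows of view 0 are stacked before rows of view 1, so the positive for row i is row (i + N) mod 2N.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    # L2-normalize so that dot products become cosine similarities.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent(
    out0: Sequence[Sequence[float]],
    out1: Sequence[Sequence[float]],
    temperature: float = 0.1,
) -> float:
    """Hypothetical reference NT-Xent: mean cross-entropy over all 2N anchors."""
    if len(out0) != len(out1):
        raise ValueError("out0 and out1 must have the same batch size")
    n = len(out0)
    # Rows 0..n-1 are view 0, rows n..2n-1 are view 1.
    z = [_normalize(v) for v in list(out0) + list(out1)]
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # the other view of the same sample
        # Temperature-scaled cosine similarity to every embedding except itself.
        logits = {
            k: sum(a * b for a, b in zip(z[i], z[k])) / temperature
            for k in range(2 * n)
            if k != i
        }
        m = max(logits.values())  # subtract the max for a stable log-sum-exp
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += log_denom - logits[positive]
    return total / (2 * n)
```

For two orthogonal samples whose views agree, the loss is log(1 + 2e^-10), roughly 1e-4, at temperature 0.1; swapping the pairing so that each positive is orthogonal to its anchor raises it to roughly 10.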

__init__(temperature=0.1, gather_distributed=False)

Initialize an instance of the class.

Parameters:

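| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature that scales the cosine similarities before the softmax; lower values sharpen the distribution. | 0.1 |
| gather_distributed | bool | If True, gather embeddings from all processes in distributed training so that negatives come from the global batch. | False |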
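The sketch below illustrates why the constructor arguments should be forwarded by keyword. It assumes the parent constructor has the signature (temperature, memory_bank_size, gather_distributed), as in recent lightly releases; ParentLoss, PositionalWrapper and KeywordWrapper are hypothetical stand-ins, not fmcib or lightly classes.

```python
class ParentLoss:
    """Hypothetical stand-in mirroring the assumed lightly NTXentLoss constructor."""

    def __init__(
        self,
        temperature: float = 0.5,
        memory_bank_size: int = 0,
        gather_distributed: bool = False,
    ):
        self.temperature = temperature
        self.memory_bank_size = memory_bank_size
        self.gather_distributed = gather_distributed


class PositionalWrapper(ParentLoss):
    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
        # Positional call: gather_distributed lands in the memory_bank_size slot.
        super().__init__(temperature, gather_distributed)


class KeywordWrapper(ParentLoss):
    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
        # Keyword call: each argument reaches the parameter of the same name.
        super().__init__(temperature=temperature, gather_distributed=gather_distributed)
```

With PositionalWrapper(gather_distributed=True), the flag ends up in memory_bank_size while gather_distributed keeps its default of False; KeywordWrapper sets it as intended.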

forward(out)

Compute the NT-Xent contrastive cross-entropy loss for a pair of views.

Parameters:

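| Name | Type | Description | Default |
|---|---|---|---|
| out | List[Tensor] | Two projected views [out0, out1], each of shape (batch_size, embedding_dim); the list is unpacked into lightly's forward(out0, out1). | required |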

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar NT-Xent (contrastive cross-entropy) loss. |
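Because forward calls super().forward(*out), the list must contain exactly two entries, which become the parent's two positional arguments. The sketch below shows only that unpacking pattern; PairLoss and ListPairLoss are hypothetical stand-ins, and the toy objective (mean negative dot product of paired rows) is deliberately not NT-Xent.

```python
from typing import List, Sequence

Vector = Sequence[float]


class PairLoss:
    """Hypothetical two-view loss with signature forward(out0, out1)."""

    def forward(self, out0: Sequence[Vector], out1: Sequence[Vector]) -> float:
        # Toy objective (not NT-Xent): mean negative dot product of paired rows.
        pairs = list(zip(out0, out1))
        return -sum(sum(a * b for a, b in zip(u, v)) for u, v in pairs) / len(pairs)


class ListPairLoss(PairLoss):
    """Accepts both views as one list, mirroring the wrapper's forward(out)."""

    def forward(self, out: List[Sequence[Vector]]) -> float:
        # Unpack [out0, out1] into the parent's two positional arguments.
        return super().forward(*out)
```

Passing a list of three views raises a TypeError, since the parent's forward accepts only two.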
