diff --git a/README.md b/README.md
index f8eb0c7b..9c29da8e 100644
--- a/README.md
+++ b/README.md
@@ -18,16 +18,16 @@
## News
+**July 25**: v2.3.0
+- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
+- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
+
**June 18**: v2.2.0
- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss).
- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss).
- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0).
- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
-**April 5**: v2.1.0
-- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss)
-- Thanks you [interestingzhuo](https://github.com/interestingzhuo).
-
## Documentation
- [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
@@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests!
| Contributor | Highlights |
| -- | -- |
-|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
+|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
- [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper|
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
@@ -246,6 +246,7 @@ Thanks to the contributors who made pull requests!
| [layumi](https://github.com/layumi) | [InstanceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) |
| [NoTody](https://github.com/NoTody) | Helped add `ref_emb` and `ref_labels` to the distributed wrappers. |
| [ElisonSherton](https://github.com/ElisonSherton) | Fixed an edge case in ArcFaceLoss. |
+| [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss |
| [z1w](https://github.com/z1w) | |
| [thinline72](https://github.com/thinline72) | |
| [tpanum](https://github.com/tpanum) | |
@@ -259,6 +260,7 @@ Thanks to the contributors who made pull requests!
| [michaeldeyzel](https://github.com/michaeldeyzel) | |
| [HSinger04](https://github.com/HSinger04) | |
| [rheum](https://github.com/rheum) | |
+| [bot66](https://github.com/bot66) | |
diff --git a/docs/losses.md b/docs/losses.md
index d5c1f4b4..85509cf9 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -807,6 +807,27 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
- [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf){target=_blank}
- [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank}
- [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank}
+
+??? "How exactly is the NTXentLoss computed?"
+
+ In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by itself and all negative pairs in the batch that have the same "anchor" embedding (`k_i in K`).
+
+ - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor.
+
+ Given `embeddings` with corresponding `labels`, positive pairs `(embeddings[i], embeddings[j])` are defined when `labels[i] == labels[j]`. Now let's look at an example loss calculation:
+
+ Consider `labels = [0, 0, 1, 2]`. Two losses will be computed:
+
+ * A positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`.
+
+ * A positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`.
+
+ Labels `1`, and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
+
+ Note that an anchor can belong to multiple positive pairs if its label is present multiple times in `labels`.
+
+ Are you trying to use `NTXentLoss` for self-supervised learning? Specifically, do you have two sets of embeddings which are derived from data that are augmented versions of each other? If so, you can skip the step of creating the `labels` array, by wrapping `NTXentLoss` with [`SelfSupervisedLoss`](losses.md#selfsupervisedloss).
+
```python
losses.NTXentLoss(temperature=0.07, **kwargs)
```
diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py
index a0ba7407..cfff0813 100644
--- a/src/pytorch_metric_learning/losses/__init__.py
+++ b/src/pytorch_metric_learning/losses/__init__.py
@@ -35,3 +35,4 @@
from .triplet_margin_loss import TripletMarginLoss
from .tuplet_margin_loss import TupletMarginLoss
from .vicreg_loss import VICRegLoss
+from .multilabel_supcon_loss import MultiSupConLoss, CrossBatchMemory4MultiLabel
\ No newline at end of file
diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
new file mode 100644
index 00000000..c8293918
--- /dev/null
+++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
@@ -0,0 +1,378 @@
+import torch
+
+from ..distances import CosineSimilarity
+from ..reducers import AvgNonZeroReducer
+from ..utils import common_functions as c_f
+from ..utils import loss_and_miner_utils as lmu
+from ..utils.module_with_records import ModuleWithRecords
+from .generic_pair_loss import GenericPairLoss
+from .base_loss_wrapper import BaseLossWrapper
+
+# adapted from https://github.com/HobbitLong/SupContrast
+# modified for multi-supcon
+class MultiSupConLoss(GenericPairLoss):
+ """
+ Args:
+ num_classes: number of classes
+ temperature: temperature for scaling the similarity matrix
+ threshold: threshold for jaccard similarity
+
+ Inputs:
+ embeddings: tensor of size (batch_size, embedding_size)
+ labels: tensor of size (batch_size, num_classes)
+ each row is a binary vector of size num_classes that only has 1s for the positive
+ labels, and 0s for the negative labels
+ indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
+ or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix)
+ Can also be left as None
+ ref_emb: tensor of size (batch_size, embedding_size)
+ """
+ def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
+ super().__init__(mat_based_loss=True, **kwargs)
+ self.temperature = temperature
+ self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
+ self.num_classes = num_classes
+ self.threshold = threshold
+
+ def _compute_loss(self, mat, pos_mask, neg_mask, multi_val):
+ if pos_mask.bool().any() and neg_mask.bool().any():
+ # if dealing with actual distances, use negative distances
+ if not self.distance.is_inverted:
+ mat = -mat
+ mat = mat / self.temperature
+ mat_max, _ = mat.max(dim=1, keepdim=True)
+ mat = mat - mat_max.detach() # for numerical stability
+
+ denominator = lmu.logsumexp(
+ mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
+ )
+ log_prob = mat - denominator
+ mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / (
+ pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
+ )
+
+ return {
+ "loss": {
+ "losses": -mean_log_prob_pos,
+ "indices": c_f.torch_arange_from_size(mat),
+ "reduction_type": "element",
+ }
+ }
+ return self.zero_losses()
+
+ def get_default_reducer(self):
+ return AvgNonZeroReducer()
+
+ def get_default_distance(self):
+ return CosineSimilarity()
+
+ # ==== class methods below are overriden for adaptability to multi-supcon ====
+
+ def mat_based_loss(self, mat, indices_tuple):
+ a1, p, a2, n, jaccard_mat = indices_tuple
+ pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
+ pos_mask[a1, p] = 1
+ neg_mask[a2, n] = 1
+ return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat)
+
+ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
+ c_f.labels_or_indices_tuple_required(labels, indices_tuple)
+ indices_tuple = convert_to_pairs(
+ indices_tuple,
+ labels,
+ ref_labels,
+ threshold=self.threshold)
+ if all(len(x) <= 1 for x in indices_tuple):
+ return self.zero_losses()
+ mat = self.distance(embeddings, ref_emb)
+ return self.loss_method(mat, indices_tuple)
+
+ def forward(
+ self, embeddings, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
+ ):
+ """
+ Args:
+ embeddings: tensor of size (batch_size, embedding_size)
+ labels: tensor of size (batch_size, num_classes)
+ each row is a binary vector of size num_classes that only has 1s for the positive
+ labels, and 0s for the negative labels
+ indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
+ or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix)
+ Can also be left as None
+ ref_emb: tensor of size (batch_size, embedding_size)
+ Returns: the loss
+ """
+ self.reset_stats()
+ check_shapes_multilabels(embeddings, labels)
+ ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels)
+ loss_dict = self.compute_loss(
+ embeddings, labels, indices_tuple, ref_emb, ref_labels
+ )
+ self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
+ return self.reducer(loss_dict, embeddings, labels)
+
+ # =========================================================================
+
+
+# ================== cross batch memory for multi-supcon ==================
+class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords):
+ def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs):
+ super().__init__(loss=loss, **kwargs)
+ self.loss = loss
+ self.miner = miner
+ self.embedding_size = embedding_size
+ self.memory_size = memory_size
+ self.num_classes = loss.num_classes
+ self.reset_queue()
+ self.add_to_recordable_attributes(
+ list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False
+ )
+
+ @staticmethod
+ def supported_losses():
+ return [
+ "MultiSupConLoss"
+ ]
+
+ @classmethod
+ def check_loss_support(cls, loss_name):
+ if loss_name not in cls.supported_losses():
+ raise Exception(f"CrossBatchMemory not supported for {loss_name}")
+
+ def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None):
+ if indices_tuple is not None and enqueue_mask is not None:
+ raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
+ if enqueue_mask is not None:
+ assert len(enqueue_mask) == len(embeddings)
+ else:
+ assert len(embeddings) <= len(self.embedding_memory)
+ self.reset_stats()
+ device = embeddings.device
+ labels = c_f.to_device(labels, device=device)
+ self.embedding_memory = c_f.to_device(
+ self.embedding_memory, device=device, dtype=embeddings.dtype
+ )
+ self.label_memory = c_f.to_device(
+ self.label_memory, device=device, dtype=labels.dtype
+ )
+
+ if enqueue_mask is not None:
+ emb_for_queue = embeddings[enqueue_mask]
+ labels_for_queue = labels[enqueue_mask]
+ embeddings = embeddings[~enqueue_mask]
+ labels = labels[~enqueue_mask]
+ do_remove_self_comparisons = False
+ else:
+ emb_for_queue = embeddings
+ labels_for_queue = labels
+ do_remove_self_comparisons = True
+
+ queue_batch_size = len(emb_for_queue)
+ self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)
+
+ if not self.has_been_filled:
+ E_mem = self.embedding_memory[: self.queue_idx]
+ L_mem = self.label_memory[: self.queue_idx]
+ else:
+ E_mem = self.embedding_memory
+ L_mem = self.label_memory
+
+ indices_tuple = self.create_indices_tuple(
+ embeddings,
+ labels,
+ E_mem,
+ L_mem,
+ indices_tuple,
+ do_remove_self_comparisons,
+ )
+ loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
+ return loss
+
+ def add_to_memory(self, embeddings, labels, batch_size):
+ self.curr_batch_idx = (
+ torch.arange(
+ self.queue_idx, self.queue_idx + batch_size, device=labels.device
+ )
+ % self.memory_size
+ )
+ self.embedding_memory[self.curr_batch_idx] = embeddings.detach()
+ self.label_memory[self.curr_batch_idx] = labels.detach()
+ prev_queue_idx = self.queue_idx
+ self.queue_idx = (self.queue_idx + batch_size) % self.memory_size
+ if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx):
+ self.has_been_filled = True
+
+ def create_indices_tuple(
+ self,
+ embeddings,
+ labels,
+ E_mem,
+ L_mem,
+ input_indices_tuple,
+ do_remove_self_comparisons,
+ ):
+ if self.miner:
+ indices_tuple = self.miner(embeddings, labels, E_mem, L_mem)
+ else:
+ indices_tuple = get_all_pairs_indices(labels, L_mem)
+
+ if do_remove_self_comparisons:
+ indices_tuple = remove_self_comparisons(
+ indices_tuple, self.curr_batch_idx, self.memory_size
+ )
+
+ if input_indices_tuple is not None:
+ if len(input_indices_tuple) == 3 and len(indices_tuple) == 4:
+ input_indices_tuple = convert_to_pairs(input_indices_tuple, labels)
+ elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3:
+ input_indices_tuple = convert_to_triplets(
+ input_indices_tuple, labels
+ )
+ indices_tuple = c_f.concatenate_indices_tuples(
+ indices_tuple, input_indices_tuple
+ )
+
+ return indices_tuple
+
+ def reset_queue(self):
+ self.register_buffer(
+ "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
+ )
+ self.register_buffer(
+ "label_memory", torch.zeros(self.memory_size, self.num_classes)
+ )
+ self.has_been_filled = False
+ self.queue_idx = 0
+
+# =========================================================================
+
+# compute jaccard similarity
+def jaccard(labels, ref_labels=None):
+ if ref_labels is None:
+ ref_labels = labels
+
+ labels1 = labels.float()
+ labels2 = ref_labels.float()
+
+ # compute jaccard similarity
+ # jaccard = intersection / union
+ labels1_union = labels1.sum(-1)
+ labels2_union = labels2.sum(-1)
+ union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0)
+ intersection = torch.mm(labels1, labels2.T)
+ jaccard_matrix = intersection / (union - intersection)
+
+ # return indices of jaccard similarity above threshold
+ return jaccard_matrix
+
+# ====== methods below are overriden for adaptability to multi-supcon ======
+
+# use jaccard similarity to get matches
+def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3):
+ if ref_labels is None:
+ ref_labels = labels
+ jaccard_matrix = jaccard(labels, ref_labels)
+ matches = torch.where(jaccard_matrix > threshold, 1, 0)
+ diffs = matches ^ 1
+ if ref_labels is labels:
+ matches.fill_diagonal_(0)
+ return matches, diffs, jaccard_matrix
+
+def check_shapes_multilabels(embeddings, labels):
+ if labels is not None and embeddings.shape[0] != labels.shape[0]:
+ raise ValueError("Number of embeddings must equal number of labels")
+ if labels is not None and labels.ndim != 2:
+ raise ValueError("labels must be a 1D tensor of shape (batch_size,)")
+
+
+def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
+ if ref_emb is None:
+ ref_emb, ref_labels = embeddings, labels
+ check_shapes_multilabels(ref_emb, ref_labels)
+ return ref_emb, ref_labels
+
+
+def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3):
+ """
+ This returns anchor-positive and anchor-negative indices,
+ regardless of what the input indices_tuple is
+ Args:
+ indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
+ within a batch
+ labels: a tensor which has the label for each element in a batch
+ """
+ if indices_tuple is None:
+ return get_all_pairs_indices(labels, ref_labels, threshold=threshold)
+ elif len(indices_tuple) == 5:
+ return indices_tuple
+ else:
+ a, p, n, jaccard_mat = indices_tuple
+ return a, p, a, n,jaccard_mat
+
+
+def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3):
+ """
+ Given a tensor of labels, this will return 4 tensors.
+ The first 2 tensors are the indices which form all positive pairs
+ The second 2 tensors are the indices which form all negative pairs
+ """
+ matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, threshold=threshold)
+ a1_idx, p_idx = torch.where(matches)
+ a2_idx, n_idx = torch.where(diffs)
+ return a1_idx, p_idx, a2_idx, n_idx, multi_val
+
+
+def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100):
+ """
+ This returns anchor-positive-negative triplets
+ regardless of what the input indices_tuple is
+ """
+ if indices_tuple is None:
+ if t_per_anchor == "all":
+ return get_all_triplets_indices(labels, ref_labels)
+ else:
+ return lmu.get_random_triplet_indices(
+ labels, ref_labels, t_per_anchor=t_per_anchor
+ )
+ elif len(indices_tuple) == 3:
+ return indices_tuple
+ else:
+ a1, p, a2, n = indices_tuple
+ p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2)
+ return a1[p_idx], p[p_idx], n[n_idx]
+
+
+def get_all_triplets_indices(labels, ref_labels=None):
+ matches, diffs = get_matches_and_diffs(labels, ref_labels)
+ triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
+ return torch.where(triplets)
+
+
+def remove_self_comparisons(
+ indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False
+):
+ # remove self-comparisons
+ assert len(indices_tuple) in [4, 5]
+ s, e = curr_batch_idx[0], curr_batch_idx[-1]
+ if len(indices_tuple) == 4:
+ a, p, n, jaccard_mat = indices_tuple
+ keep_mask = lmu.not_self_comparisons(
+ a, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+ )
+ a = a[keep_mask]
+ p = p[keep_mask]
+ n = n[keep_mask]
+ assert len(a) == len(p) == len(n)
+ return a, p, n, jaccard_mat
+ elif len(indices_tuple) == 5:
+ a1, p, a2, n, jaccard_mat = indices_tuple
+ keep_mask = lmu.not_self_comparisons(
+ a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+ )
+ a1 = a1[keep_mask]
+ p = p[keep_mask]
+ assert len(a1) == len(p)
+ assert len(a2) == len(n)
+ return a1, p, a2, n, jaccard_mat
+
+# =========================================================================
\ No newline at end of file
diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py
new file mode 100644
index 00000000..c0584448
--- /dev/null
+++ b/tests/losses/test_multilabel_supcon_loss.py
@@ -0,0 +1,126 @@
+import unittest
+
+import torch
+import numpy as np
+
+from pytorch_metric_learning.losses import (
+ MultiSupConLoss,
+ CrossBatchMemory4MultiLabel
+)
+
+from ..zzz_testing_utils.testing_utils import angle_to_coord
+
+from .. import TEST_DEVICE, TEST_DTYPES
+class TestMultiSupConLoss(unittest.TestCase):
+ def __init__(self, methodName: str = "runTest") -> None:
+ super().__init__(methodName)
+ self.n_cls = 3
+ self.n_samples = 4
+ self.n_dim = 3
+ self.n_batchs = 10
+ self.xbm_max_size = 1024
+
+ # multi_supcon
+ self.loss_func = MultiSupConLoss(
+ num_classes=self.n_cls,
+ temperature=0.07,
+ threshold=0.3)
+
+ # xbm
+ self.xbm_loss_func = CrossBatchMemory4MultiLabel(
+ self.loss_func,
+ self.n_dim,
+ memory_size=self.xbm_max_size)
+ # test cases
+ self.embeddings = torch.tensor([[0.1, 0.3, 0.1],
+ [0.23, -0.2, -0.1],
+ [0.1, -0.16, 0.1],
+ [0.13, -0.13, 0.2]])
+ self.labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]])
+
+ # the gt values are obtained by running the code
+ # (https://github.com/WolodjaZ/MultiSupContrast/blob/main/losses.py)
+
+ # multi_supcon test cases
+ self.test_multisupcon_val_gt = {
+ torch.float16: 3.2836,
+ torch.float32: 3.2874,
+ torch.float64: 3.2874,
+ }
+ # xbm test cases
+ self.test_xbm_multisupcon_val_gt = {
+ torch.float16: [3.2836, 4.3792, 4.4588, 4.5741, 4.6831, 4.7809, 4.8682, 4.9465, 5.0174, 5.0819],
+ torch.float32: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808],
+ torch.float64: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808,]
+ }
+
+
+ def test_multisupcon_val(self):
+ for dtype in TEST_DTYPES:
+ for device in ["cpu", "cuda"]:
+ # skip float16 on cpu
+ if device == "cpu" and dtype == torch.float16:
+ continue
+ embedding = self.embeddings.to(device).to(dtype)
+ label = self.labels.to(device).to(dtype)
+ loss = self.loss_func(embedding, label)
+ loss = loss.to("cpu")
+ self.assertTrue(np.isclose(
+ loss.item(),
+ self.test_multisupcon_val_gt[dtype],
+ atol=1e-2 if dtype == torch.float16 else 1e-4))
+
+
+ def test_xbm_multisupcon_val(self):
+ # test xbm with scatter labels
+ for dtype in TEST_DTYPES:
+ for device in ["cpu", "cuda"]:
+ # skip float16 on cpu
+ if device == "cpu" and dtype == torch.float16:
+ continue
+ self.xbm_loss_func.reset_queue()
+ for b in range(self.n_batchs):
+ embedding = self.embeddings.to(device).to(dtype)
+ label = self.labels.to(device).to(dtype)
+ loss = self.xbm_loss_func(embedding, label)
+ loss = loss.to("cpu")
+ print(loss, self.test_xbm_multisupcon_val_gt[dtype][b], dtype)
+ self.assertTrue(np.isclose(
+ loss.item(),
+ self.test_xbm_multisupcon_val_gt[dtype][b],
+ atol=1e-2 if dtype == torch.float16 else 1e-4))
+
+
+ def test_with_no_valid_pairs(self):
+ for dtype in TEST_DTYPES:
+ embedding_angles = [0]
+ embeddings = torch.tensor(
+ [angle_to_coord(a) for a in embedding_angles],
+ requires_grad=True,
+ dtype=dtype,
+ ).to(
+ TEST_DEVICE
+ ) # 2D embeddings
+ labels = torch.LongTensor([[0]])
+ loss = self.loss_func(embeddings, labels)
+ loss.backward()
+ self.assertEqual(loss, 0)
+
+
+ def test_backward(self):
+ for dtype in TEST_DTYPES:
+ embedding_angles = list(range(0, 180, 20))[:4]
+ embeddings = torch.tensor(
+ [angle_to_coord(a) for a in embedding_angles],
+ requires_grad=True,
+ dtype=dtype,
+ ).to(
+ TEST_DEVICE
+ ) # 2D embeddings
+ labels = torch.LongTensor([[0, 0, 1, 0, 1, 0, 0],
+ [1, 0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1, 0, 1]]).to(TEST_DEVICE)
+
+ loss = self.loss_func(embeddings, labels)
+ loss.backward()
\ No newline at end of file