def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):
    """Return the distances from each embedded query to each centroid.

    The queries are embedded with ``self.f`` and the centroids are gathered
    with ``self.get_centroid_matrix``. Both are then expanded to a common
    3-D layout so that ``torch.pairwise_distance`` can compare every
    (query, centroid) combination in one call.

    Args:
        Query_x: Batch of query inputs accepted by ``self.f``. The embedding
            is indexed with ``size(0)`` and ``size(1)``, so it is expected to
            be 2-D, ``(m, d)``.
        centroid_per_class: Forwarded unchanged to
            ``self.get_centroid_matrix``.
        Query_y_labels: Forwarded unchanged to ``self.get_centroid_matrix``.

    Returns:
        The tensor produced by ``torch.pairwise_distance`` on the two
        expanded and transposed matrices.
    """
    centroid_matrix = self.get_centroid_matrix(
        centroid_per_class, Query_y_labels)
    Query_x = self.f(Query_x)

    m = Query_x.size(0)          # number of queries
    n = centroid_matrix.size(0)  # number of centroids

    # Expand both matrices to a common (m, n, d) layout so they become
    # compatible with each other for the L2 distance computation.
    # Centroids: (n, d) -> (m, n, d), one copy per query.
    centroid_matrix = centroid_matrix.expand(
        m, centroid_matrix.size(0), centroid_matrix.size(1))
    # Queries: (m, d) -> (n, m, d) -> (m, n, d), one copy per centroid.
    Query_matrix = Query_x.expand(
        n, Query_x.size(0), Query_x.size(1)).transpose(0, 1)

    # Both operands are transposed to (m, d, n) before the call.
    # NOTE(review): which axis torch.pairwise_distance reduces for 3-D inputs
    # is believed to differ between PyTorch releases (dim 1 vs. the last
    # dim) - confirm the output shape against the installed version.
    Qx = torch.pairwise_distance(
        centroid_matrix.transpose(1, 2), Query_matrix.transpose(1, 2))
    return Qx
def pdist2(x, y):
    """Compute the distance between each pair of row vectors in x and y.

    Args:
        x: tensor of shape n*p
        y: tensor of shape m*p

    Returns:
        dist: tensor of shape n*m, where ``dist[i, j]`` is
        ``torch.pairwise_distance`` (default arguments) between ``x[i]``
        and ``y[j]``.
    """
    p = x.shape[1]
    n = x.shape[0]
    m = y.shape[0]

    # Repeat every row of x m times in a row: x0, x0, ..., x1, x1, ...
    xtile = torch.cat([x] * m, dim=1).view(-1, p)
    # Repeat the whole of y n times: y0, y1, ..., y0, y1, ...
    ytile = torch.cat([y] * n, dim=0)

    # Row k of xtile/ytile now holds the pair (i, j) with k = i * m + j.
    dist = torch.pairwise_distance(xtile, ytile)
    return dist.view(n, m)
def perceptual_features_reconstruction(list_attentions_a, list_attentions_b, factor=1.):
    """Average, over layers, the squared L2 distance between normalized features.

    For every pair of 4-D feature maps the tensors are flattened per sample,
    L2-normalized, compared with ``F.pairwise_distance`` and the squared
    distance is divided by the number of features ``c * w * h``.

    Args:
        list_attentions_a: Sequence of tensors of shape ``(bs, c, w, h)``.
        list_attentions_b: Sequence of tensors paired with
            ``list_attentions_a`` by position. Pairs are formed with ``zip``,
            so extra entries in the longer sequence are ignored, while the
            final average still divides by ``len(list_attentions_a)``.
        factor: Scalar multiplier applied to the averaged loss.

    Returns:
        ``factor`` times the mean of the per-layer losses.

    Raises:
        ZeroDivisionError: If ``list_attentions_a`` is empty.
    """
    loss = 0.

    for a, b in zip(list_attentions_a, list_attentions_b):
        bs, c, w, h = a.shape

        # Flatten (bs, c, w, h) -> (bs, c * w * h). reshape, unlike view, also
        # accepts non-contiguous inputs such as permuted feature maps.
        a = F.normalize(a.reshape(bs, -1), p=2, dim=-1)
        b = F.normalize(b.reshape(bs, -1), p=2, dim=-1)

        layer_loss = (F.pairwise_distance(a, b, p=2) ** 2) / (c * w * h)
        loss += torch.mean(layer_loss)

    return factor * (loss / len(list_attentions_a))
def pairwise_distance(
    x1: torch.Tensor,
    x2: torch.Tensor,
    p: float = 2.,
    eps: float = 1e-6,
    keepdim: bool = False,
) -> torch.Tensor:
    r"""Thin wrapper that forwards all arguments to ``torch.pairwise_distance``.

    See :class:`torch.nn.PairwiseDistance` for details.

    Args:
        x1: First input tensor.
        x2: Second input tensor.
        p: Norm degree, forwarded as-is.
        eps: Small constant forwarded as-is.
        keepdim: Whether the reduced dimension is kept, forwarded as-is.

    Returns:
        The result of ``torch.pairwise_distance(x1, x2, p, eps, keepdim)``.
    """
    return torch.pairwise_distance(x1, x2, p, eps, keepdim)
def forward_gmmn(self, visual_features, semantic_features, class_id, words, metrics):
    """Compute the GMMN loss, optionally adding a distillation term.

    The base loss is ``mmd`` between the visual and the semantic features.
    When ``self.gmmn_config["old_mmd"]`` is set and
    ``self._old_word_embeddings`` is available, a second term is added that
    compares ``semantic_features`` with the output of the old word
    embeddings for the same ``words``.

    Args:
        visual_features: Passed to ``mmd`` as ``real``.
        semantic_features: Passed to ``mmd`` as ``fake`` and compared with
            the old semantic features in the distillation term.
        class_id: Compared with ``self._n_classes - self._task_size``. At or
            above that limit the distillation term is skipped, unless
            ``old_mmd["apply_unseen"]`` is truthy.
        words: Input given to ``self._old_word_embeddings``.
        metrics: Mapping whose ``"old"`` entry is incremented in place with
            the scalar value of the distillation term.

    Returns:
        The base loss, plus the distillation term when it applies.

    Raises:
        ValueError: If ``old_mmd["type"]`` is not one of ``"mmd"``, ``"kl"``,
            ``"l2"`` or ``"cosine"``.
    """
    loss = mmd(real=visual_features, fake=semantic_features, **self.gmmn_config["mmd"])

    old_config = self.gmmn_config.get("old_mmd")
    if not old_config or self._old_word_embeddings is None:
        return loss

    old_unseen_limit = self._n_classes - self._task_size
    if not old_config.get("apply_unseen", False) and class_id >= old_unseen_limit:
        return loss

    # The old embeddings only provide a target; no gradient is needed.
    with torch.no_grad():
        old_semantic_features = self._old_word_embeddings(words)

    factor = old_config["factor"]
    _type = old_config.get("type", "mmd")
    if _type == "mmd":
        old_loss = factor * mmd(
            real=old_semantic_features, fake=semantic_features, **self.gmmn_config["mmd"]
        )
    elif _type == "kl":
        old_loss = factor * F.kl_div(
            semantic_features, old_semantic_features, reduction="batchmean"
        )
    elif _type == "l2":
        # NOTE(review): pairwise_distance yields one value per row; it is
        # averaged so that `.item()` below receives a single-element tensor.
        old_loss = factor * torch.pairwise_distance(
            semantic_features, old_semantic_features, p=2
        ).mean()
    elif _type == "cosine":
        # NOTE(review): averaged for the same reason as the "l2" branch.
        old_loss = factor * (
            1 - torch.cosine_similarity(semantic_features, old_semantic_features)
        ).mean()
    else:
        raise ValueError(f"Unknown distillation: {_type}.")

    if self.gmmn_config.get("scheduled"):
        # Scale the distillation term by sqrt(total classes / task size).
        old_loss = old_loss * math.sqrt(self._n_classes / self._task_size)

    metrics["old"] += old_loss.item()
    return loss + old_loss
def mmd(x, y, sigmas=[1, 5, 10], normalize=False):
    """Maximum Mean Discrepancy with several Gaussian kernels.

    Both inputs are flattened to 2-D ``(batch, features)``. Squared distances
    come from ``torch.pairwise_distance``, i.e. they are taken between rows
    that share the same index, not between all pairs of rows.

    Args:
        x: First batch of samples; flattened to ``(x.shape[0], -1)``.
        y: Second batch of samples; flattened to ``(y.shape[0], -1)``.
        sigmas: Kernel bandwidths handed to ``_get_mmd_factor``. When empty,
            a single factor ``-1 / (2 * mean squared distance(x, y))`` is
            derived from the data instead. The default list is shared between
            calls; this function itself never mutates it.
        normalize: If true, rows of ``x`` and ``y`` are L2-normalized before
            the kernel distances are computed (the data-driven factor above
            is computed before normalization).

    Returns:
        ``sqrt(sum(k_xx) - 2 * sum(k_xy) + sum(k_yy))`` as a tensor.
    """
    # Flatten:
    x = x.view(x.shape[0], -1)
    y = y.view(y.shape[0], -1)

    if len(sigmas) == 0:
        mean_dist = torch.mean(torch.pow(torch.pairwise_distance(x, y, p=2), 2))
        factors = (-1 / (2 * mean_dist)).view(1, 1, 1)
    else:
        # presumably one factor per sigma, broadcastable against the
        # distance vectors below - verify against _get_mmd_factor.
        factors = _get_mmd_factor(sigmas, x.device)

    if normalize:
        x = F.normalize(x, p=2, dim=1)
        y = F.normalize(y, p=2, dim=1)

    xx = torch.pairwise_distance(x, x, p=2) ** 2
    yy = torch.pairwise_distance(y, y, p=2) ** 2
    xy = torch.pairwise_distance(x, y, p=2) ** 2

    # Every kernel term is scaled by 1 / (number of features)^2.
    div = 1 / (x.shape[1] ** 2)

    # Sum the kernel responses over the leading (factor) axis.
    k_xx = div * torch.exp(factors * xx).sum(0).squeeze()
    k_yy = div * torch.exp(factors * yy).sum(0).squeeze()
    k_xy = div * torch.exp(factors * xy).sum(0).squeeze()

    mmd_sq = torch.sum(k_xx) - 2 * torch.sum(k_xy) + torch.sum(k_yy)
    return torch.sqrt(mmd_sq)
def _test_distrib_integration(device):
    """Check ``MeanPairwiseDistance`` inside an ``Engine`` against a direct computation.

    Each process feeds the engine its own ``offset``-sized window of
    ``y_preds``/``y_true`` (selected by ``idist.get_rank()``) in batches of
    ``s`` rows. The metric reported by the engine is then compared with the
    mean of ``torch.pairwise_distance`` over the rows of every rank.

    Args:
        device: Device the random test tensors are moved to.
    """
    import numpy as np
    from ignite.engine import Engine

    rank = idist.get_rank()
    world_size = idist.get_world_size()

    n_iters = 100
    s = 50  # batch size per iteration
    offset = n_iters * s  # rows consumed by one rank

    # The reference below spans the rows of all ranks but is computed from
    # this process's tensors, so every process must generate identical data.
    # A fixed-seed local generator does that without touching global RNG state.
    # NOTE(review): assumes the metric aggregates across ranks - confirm.
    generator = torch.Generator().manual_seed(12)
    y_true = torch.rand(offset * world_size, 10, generator=generator).to(device)
    y_preds = torch.rand(offset * world_size, 10, generator=generator).to(device)

    def update(engine, i):
        # Batch i of this rank's window, returned as (y_pred, y).
        return (
            y_preds[i * s + offset * rank : (i + 1) * s + offset * rank, ...],
            y_true[i * s + offset * rank : (i + 1) * s + offset * rank, ...],
        )

    engine = Engine(update)

    m = MeanPairwiseDistance()
    m.attach(engine, "mpwd")

    data = list(range(n_iters))
    engine.run(data=data, max_epochs=1)

    assert "mpwd" in engine.state.metrics
    res = engine.state.metrics["mpwd"]

    # Reference value: per-row distances for every batch of every rank,
    # using the same p and eps as the metric instance.
    true_res = []
    for i in range(n_iters * world_size):
        true_res.append(
            torch.pairwise_distance(
                y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps
            )
            .cpu()
            .numpy()
        )
    true_res = np.array(true_res).ravel()
    true_res = true_res.mean()

    assert pytest.approx(res) == true_res