Python source examples: torch.histc()
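torch.histc(input, bins=100, min=0, max=0) computes the histogram of a tensor: the range [min, max] is split into `bins` equal-width bins, elements falling outside the range are ignored, and when min and max are both 0 the minimum and maximum of the data are used instead. The examples below are collected from open-source projects; as a minimal warm-up sketch (not taken from any of them, values made up):

import torch

# four bins of width 1 over [0, 4]; 9.0 lies above max and is not counted
x = torch.tensor([0.5, 1.5, 1.7, 3.2, 9.0])
hist = torch.histc(x, bins=4, min=0, max=4)
print(hist)  # tensor([1., 2., 0., 1.])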
Example 1
def batch_intersection_union(output, target, nclass):
    """mIoU"""
    # inputs are torch tensors: output is 4D, target is 3D
    # the category -1 is the ignored class, typically background / boundary
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1

    predict = predict.float() * (target > 0).float()
    intersection = predict * (predict == target).float()
    # areas of intersection and union
    area_inter = torch.histc(intersection, bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict, bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target, bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert torch.sum(area_inter > area_union).item() == 0, \
        "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
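The function above returns per-class intersection and union pixel counts; per-class IoU and mIoU follow by division. A minimal usage sketch, assuming `output` holds raw class scores of shape (N, nclass, H, W) and `target` holds labels in {-1, 0, ..., nclass-1} (shapes and values made up):

import torch

nclass = 3
output = torch.randn(2, nclass, 8, 8)          # raw scores
target = torch.randint(-1, nclass, (2, 8, 8))  # -1 marks ignored pixels

area_inter, area_union = batch_intersection_union(output, target, nclass)
iou = area_inter / (area_union + 1e-10)        # epsilon guards classes absent from the batch
print(iou, iou.mean().item())                  # per-class IoU and mIoU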
Example 2
def batch_intersection_union(output, target, nclass):
    """mIoU"""
    # inputs are torch tensors: output is 4D, target is 3D
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1

    predict = predict.float() * (target > 0).float()
    intersection = predict * (predict == target).float()
    # areas of intersection and union
    # value 0 in `intersection` (ignored or mismatched pixels) is where this differs from
    # np.bincount: histc with min=1 drops it, so labelling boundary pixels as -1 is necessary
    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert torch.sum(area_inter > area_union).item() == 0, \
        "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example 3
def get_selabel_vector(target, nclass):
    r"""Get the SE-Loss label vector for a batch.
    Args:
        target: label 3D tensor (BxHxW)
        nclass: number of categories (int)
    Output:
        2D tensor (BxnClass)
    """
    batch = target.size(0)
    tvect = torch.zeros(batch, nclass)
    for i in range(batch):
        hist = torch.histc(target[i].data.float(),
                           bins=nclass, min=0,
                           max=nclass - 1)
        vect = hist > 0
        tvect[i] = vect
    return tvect
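The histogram here is only a presence test: a class is marked for an image iff its bin count is nonzero. A small sketch of the resulting label vector, with a made-up batch of two 2x2 label maps and nclass=4:

import torch

target = torch.tensor([[[0, 1], [1, 3]],
                       [[2, 2], [2, 0]]])

tvect = get_selabel_vector(target, nclass=4)
print(tvect)
# tensor([[1., 1., 0., 1.],
#         [1., 0., 1., 0.]])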
Example 4
def cal_hist(image):
    """
    Compute the cumulative histogram (CDF) for each channel.
    """
    hists = []
    for i in range(0, 3):
        channel = image[i]
        # channel = image[i, :, :]
        channel = torch.from_numpy(channel)
        # hist, _ = np.histogram(channel, bins=256, range=(0,255))
        hist = torch.histc(channel, bins=256, min=0, max=256)
        hist = hist.numpy()
        # refHist=hist.view(256,1)
        total = hist.sum()
        pdf = [v / total for v in hist]
        # accumulate the pdf into a cdf
        for j in range(1, 256):
            pdf[j] = pdf[j - 1] + pdf[j]
        hists.append(pdf)
    return hists
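The per-channel cumulative histograms (CDFs) produced above are the usual ingredient for histogram matching. A hedged usage sketch with a made-up CHW image; the cast to float32 matters because torch.histc generally expects floating-point input:

import numpy as np
import torch

image = np.random.randint(0, 256, size=(3, 64, 64)).astype(np.float32)  # CHW, values 0..255

cdfs = cal_hist(image)
print(len(cdfs), len(cdfs[0]))  # 3 channels, 256 CDF entries each
print(cdfs[0][-1])              # last CDF entry is ~1.0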
Example 5
def intersectionAndUnion(batch_data, pred, numClass):
    (imgs, segs, infos) = batch_data
    _, preds = torch.max(pred.data.cpu(), dim=1)

    # compute area intersection
    intersect = preds.clone()
    intersect[torch.ne(preds, segs)] = -1

    area_intersect = torch.histc(intersect.float(),
                                 bins=numClass,
                                 min=0,
                                 max=numClass - 1)

    # compute area union:
    preds[torch.lt(segs, 0)] = -1
    area_pred = torch.histc(preds.float(),
                            bins=numClass,
                            min=0,
                            max=numClass - 1)
    area_lab = torch.histc(segs.float(),
                           bins=numClass,
                           min=0,
                           max=numClass - 1)
    area_union = area_pred + area_lab - area_intersect
    return area_intersect, area_union
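Marking pixels as -1 works because torch.histc silently drops values outside [min, max]; with min=0 the mismatched (or ignored) positions never reach a bin. A small stand-alone sketch of that behaviour with made-up predictions for 3 classes:

import torch

preds = torch.tensor([0., 1., 1., 2.])
segs = torch.tensor([0., 2., 1., 2.])

intersect = preds.clone()
intersect[torch.ne(preds, segs)] = -1  # mismatches fall below min=0 and are dropped

print(torch.histc(intersect, bins=3, min=0, max=2))
# tensor([1., 1., 1.])  -> one correctly predicted pixel per class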
Example 6
def intersection_union(gt, pred, correct, n_class):
    intersect = pred * correct
    area_intersect = torch.histc(intersect, bins=n_class, min=1, max=n_class)
    area_pred = torch.histc(pred, bins=n_class, min=1, max=n_class)
    area_gt = torch.histc(gt, bins=n_class, min=1, max=n_class)
    # intersect = intersect.detach().to('cpu').numpy()
    # pred = pred.detach().to('cpu').numpy()
    # gt = gt.detach().to('cpu').numpy()
    # area_intersect, _ = np.histogram(intersect, bins=n_class, range=(1, n_class))
    # area_pred, _ = np.histogram(pred, bins=n_class, range=(1, n_class))
    # area_gt, _ = np.histogram(gt, bins=n_class, range=(1, n_class))
    area_union = area_pred + area_gt - area_intersect
    return area_intersect, area_union
Example 7
def forward(self, feat_t0, feat_t1, ground_truth):
    n, c, h, w = feat_t0.data.shape
    out_t0_rz = torch.transpose(feat_t0.view(c, h * w), 1, 0)
    out_t1_rz = torch.transpose(feat_t1.view(c, h * w), 1, 0)
    gt_np = ground_truth.view(h * w).data.cpu().numpy()
    ### inspired by the source code of the Histogram loss ###
    ### indices of all positive pairs and negative pairs ###
    pos_inds_np, neg_inds_np = np.squeeze(np.where(gt_np == 0), 1), np.squeeze(np.where(gt_np != 0), 1)
    pos_size, neg_size = pos_inds_np.shape[0], neg_inds_np.shape[0]
    pos_inds, neg_inds = torch.from_numpy(pos_inds_np).cuda(), torch.from_numpy(neg_inds_np).cuda()
    ### similarities (l2 distance) for all positions ###
    distance = torch.squeeze(self.various_distance(out_t0_rz, out_t1_rz), dim=1)
    ### build similarity histograms of positive and negative pairs ###
    pos_dist_ls, neg_dist_ls = distance[pos_inds], distance[neg_inds]
    pos_dist_ls_t = torch.from_numpy(pos_dist_ls.data.cpu().numpy())
    neg_dist_ls_t = torch.from_numpy(neg_dist_ls.data.cpu().numpy())
    hist_pos = Variable(torch.histc(pos_dist_ls_t, bins=100, min=0, max=1) / pos_size, requires_grad=True)
    hist_neg = Variable(torch.histc(neg_dist_ls_t, bins=100, min=0, max=1) / neg_size, requires_grad=True)
    loss = self.distance(hist_pos, hist_neg)
    return loss
Example 8
def forward(self, y, batch):
    if self.use_cuda:
        hist = Variable(
            torch.histc(y.cpu().data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1
        ).cuda()
    else:
        hist = Variable(
            torch.histc(y.data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1
        )

    centers_count = hist.index_select(0, y.long())  # 1 + how many examples of y[i]-th class

    batch_size = batch.size()[0]
    embeddings = batch.view(batch_size, -1)
    assert embeddings.size()[1] == self.embedding_size

    centers_pred = self.centers.index_select(0, y.long())
    diff = embeddings - centers_pred
    loss = 1 / 2.0 * (diff.pow(2).sum(1) / centers_count).sum()
    return loss
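In this center-loss forward pass, histc acts as a per-class counter: with bins=num_classes over [0, num_classes] each integer label i lands in bin i, and the +1 keeps the later division well-defined for classes absent from the batch. A stand-alone sketch of just that counting step (labels made up):

import torch

num_classes = 4
y = torch.tensor([0, 2, 2, 3])

hist = torch.histc(y.float(), bins=num_classes, min=0, max=num_classes) + 1
print(hist)           # tensor([2., 1., 3., 2.])  -> 1 + count of each class

centers_count = hist.index_select(0, y.long())
print(centers_count)  # tensor([2., 3., 3., 2.])  -> each sample's own class count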
Example 9
def accuracy(pred_cls, true_cls, nclass=79):
    """
    Function to calculate accuracy (TP / (TP + FP + TN + FN)).
    :param pytorch.Tensor pred_cls: network prediction (categorical)
    :param pytorch.Tensor true_cls: ground truth (categorical)
    :param int nclass: number of classes
    :return:
    """
    positive = torch.histc(true_cls.cpu().float(), bins=nclass, min=0, max=nclass, out=None)
    per_cls_counts = []
    tpos = []
    for i in range(1, nclass):
        true_positive = ((pred_cls == i).float() + (true_cls == i).float()).eq(2).sum().item()
        tpos.append(true_positive)
        per_cls_counts.append(positive[i])
    return np.array(tpos), np.array(per_cls_counts)
Example 10
def forward(self, inputs, target):
    """
    :param inputs: predictions (N, C, H, W)
    :param target: target distribution (N, C, H, W)
    :return: loss with image-wise weighting factor
    """
    assert inputs.size() == target.size()
    mask = (target != self.ignore_index)
    _, argpred = torch.max(inputs, 1)

    weights = []
    batch_size = inputs.size(0)
    for i in range(batch_size):
        hist = torch.histc(argpred[i].cpu().data.float(),
                           bins=self.num_class, min=0,
                           max=self.num_class - 1).float()
        weight = (1 / torch.max(torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio),
                                torch.ones(1))).to(argpred.device)[argpred[i]].detach()
        weights.append(weight)
    weights = torch.stack(weights, dim=0)

    log_likelihood = F.log_softmax(inputs, dim=1)
    loss = torch.sum((torch.mul(-log_likelihood, target) * weights)[mask]) / (batch_size * self.num_class)
    return loss
Example 11
def _get_batch_label_vector(target, nclass):
    # target is a 3D Variable BxHxW, output is 2D BxnClass
    batch = target.size(0)
    tvect = Variable(torch.zeros(batch, nclass))
    for i in range(batch):
        hist = torch.histc(target[i].cpu().data.float(),
                           bins=nclass, min=0,
                           max=nclass - 1)
        vect = hist > 0
        tvect[i] = vect
    return tvect
Example 12
def _get_batch_label_vector(target, nclass):
    # target is a 3D Variable BxHxW, output is 2D BxnClass
    batch = target.size(0)
    tvect = Variable(torch.zeros(batch, nclass))
    for i in range(batch):
        hist = torch.histc(target[i].cpu().data.float(),
                           bins=nclass, min=0,
                           max=nclass-1)
        vect = hist>0
        tvect[i] = vect
    return tvect
Example 13
def energy_spectrum(vel):
    """
    Compute energy spectrum given a velocity field
    :param vel: tensor of shape (N, 3, res, res, res)
    :return spec: tensor of shape (N, res/2)
    :return k: tensor of shape (res/2,), frequencies corresponding to spec
    """
    device = vel.device
    res = vel.shape[-2:]

    assert(res[0] == res[1])
    r = res[0]
    k_end = int(r / 2)
    vel_ = pad_rfft3(vel, onesided=False)  # (N, 3, res, res, res, 2)
    uu_ = (torch.norm(vel_, dim=-1) / r**3)**2
    e_ = torch.sum(uu_, dim=1)  # (N, res, res, res)
    k = fftfreqs(res).to(device)  # (3, res, res, res)
    rad = torch.norm(k, dim=0)  # (res, res, res)
    k_bin = torch.arange(k_end, device=device).float() + 1
    bins = torch.zeros(k_end + 1).to(device)
    bins[1:-1] = (k_bin[1:] + k_bin[:-1]) / 2
    bins[-1] = k_bin[-1]
    bins[1:] += 1e-3          # nudge the upper edges so boundary radii fall into the lower bin
    bins = bins.unsqueeze(0)  # add the batch dim expected by searchsorted
    inds = searchsorted(bins, rad.flatten().unsqueeze(0)).squeeze().int()
    # bincount = torch.histc(inds.cpu(), bins=bins.shape[1]+1).to(device)
    bincount = torch.bincount(inds)
    asort = torch.argsort(inds.squeeze())
    sorted_e_ = e_.view(e_.shape[0], -1)[:, asort]
    csum_e_ = torch.cumsum(sorted_e_, dim=1)
    binloc = torch.cumsum(bincount, dim=0).long() - 1
    spec_ = csum_e_[:, binloc[1:]] - csum_e_[:, binloc[:-1]]
    spec_ = spec_[:, :-1]
    spec_ = spec_ * 2 * np.pi * (k_bin.float()**2) / bincount[1:-1].float()
    return spec_, k_bin
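The commented-out histc call above hints at the near-equivalence with torch.bincount for non-negative integer indices: bincount works directly on integer tensors, while histc needs a float cast and an explicit range. A small equivalence sketch (index values made up):

import torch

inds = torch.tensor([0, 1, 1, 3, 3, 3])
n = 5

print(torch.bincount(inds, minlength=n))                # tensor([1, 2, 0, 3, 0])
print(torch.histc(inds.float(), bins=n, min=0, max=n))  # tensor([1., 2., 0., 3., 0.])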
Example 14
def _get_batch_label_vector(target, nclass):
    # target is a 3D Variable BxHxW, output is 2D BxnClass
    batch = target.size(0)
    tvect = Variable(torch.zeros(batch, nclass))
    for i in range(batch):
        hist = torch.histc(target[i].cpu().data.float(),
                           bins=nclass, min=0,
                           max=nclass - 1)
        vect = hist > 0
        tvect[i] = vect
    return tvect
Example 15
def get_doc_freqs_t(cnts):
    """
    Return word --> # of docs it appears in (torch version).
    """
    return torch.histc(
        cnts._indices()[0].float(), bins=cnts.size(0), min=0, max=cnts.size(0)
    )
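`cnts` is assumed here to be a sparse word-by-document count matrix, so the row indices of its nonzero entries say which word each entry belongs to, and histc over those indices counts in how many documents each word occurs. A hedged sketch with a tiny made-up sparse matrix:

import torch

# 3 words x 4 documents; entry (w, d) means word w occurs in document d
indices = torch.tensor([[0, 0, 1, 2, 2, 2],
                        [0, 3, 1, 0, 1, 2]])
values = torch.ones(6)
cnts = torch.sparse_coo_tensor(indices, values, size=(3, 4)).coalesce()

print(get_doc_freqs_t(cnts))  # tensor([2., 1., 3.])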
Example 16
def get_doc_freqs_t(cnts):
    """Return word --> # of docs it appears in (torch version)."""
    return torch.histc(
        cnts._indices()[0].float(), bins=cnts.size(0), min=0, max=cnts.size(0)
    )
Example 17
def batch_intersection_union(output, target, num_class):
    _, predict = torch.max(output, 1)
    predict = predict + 1
    target = target + 1

    predict = predict * (target > 0).long()
    intersection = predict * (predict == target).long()
    area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
    area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
    area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
    return area_inter.cpu().numpy(), area_union.cpu().numpy()
Example 18
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # https://github.com/pytorch/pytorch/issues/1382
    area_intersection = torch.histc(intersection.float().cpu(), bins=K, min=0, max=K-1)
    area_output = torch.histc(output.float().cpu(), bins=K, min=0, max=K-1)
    area_target = torch.histc(target.float().cpu(), bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection.cuda(), area_union.cuda(), area_target.cuda()
Example 19
def calculate_histogram(self, abstract_features_1, abstract_features_2):
    """
    Calculate histogram from similarity matrix.
    :param abstract_features_1: Feature matrix for graph 1.
    :param abstract_features_2: Feature matrix for graph 2.
    :return hist: Histogram of similarity scores.
    """
    scores = torch.mm(abstract_features_1, abstract_features_2).detach()
    scores = scores.view(-1, 1)
    hist = torch.histc(scores, bins=self.args.bins)
    hist = hist / torch.sum(hist)
    hist = hist.view(1, -1)
    return hist
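Note that this call leaves min and max at their defaults, so histc uses the minimum and maximum of `scores` as the range, and the counts are then normalized into a discrete distribution. A stand-alone sketch of that normalization with made-up scores and 4 bins:

import torch

scores = torch.tensor([[0.1, 0.4], [0.9, 0.4]]).view(-1, 1)

hist = torch.histc(scores, bins=4)  # range defaults to [scores.min(), scores.max()]
hist = hist / torch.sum(hist)
print(hist.view(1, -1))             # tensor([[0.2500, 0.5000, 0.0000, 0.2500]])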
Example 20
def put_histogram(self, hist_name, hist_tensor, bins=1000):
    """
    Create a histogram from a tensor.

    Args:
        hist_name (str): The name of the histogram to put into tensorboard.
        hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
            into a histogram.
        bins (int): Number of histogram bins.
    """
    ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

    # Create a histogram with PyTorch
    hist_counts = torch.histc(hist_tensor, bins=bins)
    hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

    # Parameter for the add_histogram_raw function of SummaryWriter
    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        num=len(hist_tensor),
        sum=float(hist_tensor.sum()),
        sum_squares=float(torch.sum(hist_tensor ** 2)),
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=self._iter,
    )
    self._histograms.append(hist_params)
Example 21
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
    area_output = torch.histc(output, bins=K, min=0, max=K-1)
    area_target = torch.histc(target, bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
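Unlike the variant in Example 18, this version keeps everything on the device and skips the .float() casts, so it assumes the inputs are already floating-point: torch.histc generally rejects integer dtypes. A minimal CPU sketch with made-up labels and an explicit cast:

import torch

output = torch.tensor([0, 1, 2, 2, 255])  # int64 predictions, 255 = ignore
target = torch.tensor([0, 2, 2, 2, 255])

inter, union, area_t = intersectionAndUnionGPU(output.float(), target.float(), K=3)
print(inter, union, area_t)
# tensor([1., 0., 2.]) tensor([1., 1., 3.]) tensor([1., 0., 3.])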
Example 22
def get_flow_histogram(flow):
    flow_magnitude = ((flow[..., 0] ** 2 + flow[..., 1] ** 2) ** 0.5).flatten()
    flow_magnitude[flow_magnitude > 99] = 99
    return torch.histc(flow_magnitude, min=0, max=100) / len(flow_magnitude)
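With bins left at its default of 100 and the range pinned to [0, 100], each bin spans one unit of flow magnitude, and the clamp at 99 keeps large motions inside the last bin instead of being silently dropped. A hedged sketch with a made-up (H, W, 2) flow field:

import torch

flow = torch.randn(32, 32, 2) * 5  # made-up flow field

hist = get_flow_histogram(flow)
print(hist.shape, hist.sum())      # torch.Size([100]), sums to ~1.0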
Example 23
def get_wue_align(self, prev_output_tokens,
                  encoder_out, incremental_state=None):
    # source embeddings
    src_emb = encoder_out['encoder_out']  # B, Ts, ds
    # target embeddings:
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        # embed the last target token
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # Build the full grid
    tgt_emb = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if positions is not None:
        tgt_emb += positions

    tgt_emb = self.ln(tgt_emb)
    tgt_emb = self.embedding_dropout(tgt_emb)

    src_length = src_emb.size(1)
    tgt_length = tgt_emb.size(1)

    # build 2d "image" of embeddings
    src_emb = _expand(src_emb, 1, tgt_length)  # B, Tt, Ts, ds
    tgt_emb = _expand(tgt_emb, 2, src_length)  # B, Tt, Ts, dt
    x = torch.cat((src_emb, tgt_emb), dim=3)   # B, Tt, Ts, C=ds+dt
    x = self.input_dropout(x)

    # pass through dense convolutional layers
    x = self.net(x, incremental_state)  # B, Tt, Ts, C
    x, indices = x.max(dim=2)  # B, Tt, C
    # only works for N=1
    counts = [torch.histc(indices[:, i], bins=src_length, min=0, max=src_length - 1) for i in range(tgt_length)]
    counts = [c.float() / torch.sum(c) for c in counts]
    align = torch.stack(counts, dim=0).unsqueeze(0)  # 1, Tt, Ts
    return [align]
Example 24
def forward(self, pred, prob, label=None):
    """
    :param pred: predictions (N, C, H, W)
    :param prob: probability of pred (N, C, H, W)
    :param label (optional): the map for counting label numbers (N, C, H, W)
    :return: maximum squares loss with image-wise weighting factor
    """
    # prob -= 0.5
    N, C, H, W = prob.size()
    mask = (prob != self.ignore_index)
    maxpred, argpred = torch.max(prob, 1)
    mask_arg = (maxpred != self.ignore_index)
    argpred = torch.where(mask_arg, argpred,
                          torch.ones(1).to(prob.device, dtype=torch.long) * self.ignore_index)
    if label is None:
        label = argpred

    weights = []
    batch_size = prob.size(0)
    for i in range(batch_size):
        hist = torch.histc(label[i].cpu().data.float(),
                           bins=self.num_class + 1, min=-1,
                           max=self.num_class - 1).float()
        hist = hist[1:]
        weight = (1 / torch.max(torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio),
                                torch.ones(1))).to(argpred.device)[argpred[i]].detach()
        weights.append(weight)
    weights = torch.stack(weights, dim=0)
    mask = mask_arg.unsqueeze(1).expand_as(prob)

    prior = torch.mean(prob, (2, 3), True).detach()
    loss = -torch.sum((torch.pow(prob, 2) * weights)[mask]) / (batch_size * self.num_class)
    return loss
Example 25
def _findcluster(self):
    """Finds a cluster to output."""
    threshold = None

    # Keep looping until we find a cluster
    while threshold is None:
        # If on GPU, we need to take the next seed which has not already been clustered out.
        # If not, clustered points have been removed, so we can just take the next seed.
        if self.CUDA:
            self.seed = (self.seed + 1) % len(self.matrix)
            while self.kept_mask[self.seed] == False:
                self.seed = (self.seed + 1) % len(self.matrix)
        else:
            self.seed = (self.seed + 1) % len(self.matrix)

        medoid, distances = _wander_medoid(self.matrix, self.kept_mask, self.seed, self.MAXSTEPS, self.RNG, self.CUDA)

        # We need to make a histogram of only the unclustered distances - when run on GPU
        # these have not been removed and we must use the kept_mask
        if self.CUDA:
            _torch.histc(distances[self.kept_mask], len(self.histogram), 0, _XMAX, out=self.histogram)
        else:
            _torch.histc(distances, len(self.histogram), 0, _XMAX, out=self.histogram)
        self.histogram[0] -= 1  # Remove distance to self

        threshold, success = _find_threshold(self.histogram, self.peak_valley_ratio, self.CUDA)

        # If success is not None, threshold detection either failed or succeeded.
        if success is not None:
            # Keep an accurate count of successes if we exceed maxlen
            if len(self.attempts) == self.attempts.maxlen:
                self.successes -= self.attempts.popleft()

            # Add the current success to the count
            self.successes += success
            self.attempts.append(success)

            # If fewer than minsuccesses of the last maxlen attempts were successful,
            # we relax the clustering criteria and reset the success count.
            if len(self.attempts) == self.attempts.maxlen and self.successes < self.MINSUCCESSES:
                self.peak_valley_ratio += 0.1
                self.attempts.clear()
                self.successes = 0

    # These are the points of the final cluster AFTER establishing the threshold used
    points = _smaller_indices(distances, self.kept_mask, threshold, self.CUDA)
    isdefault = success is None and threshold == _DEFAULT_RADIUS and self.peak_valley_ratio > 0.55

    cluster = Cluster(self.indices[medoid].item(), self.seed, self.indices[points].numpy(),
                      self.peak_valley_ratio, threshold, isdefault, self.successes, len(self.attempts))
    return cluster, medoid, points
Example 26
def forward(self, query: Dict[str, torch.Tensor], document: Dict[str, torch.Tensor]) -> torch.Tensor:
    # pylint: disable=arguments-differ

    #
    # prepare embedding tensors
    # -------------------------------------------------------

    # we assume 1 is the unknown token, 0 is padding - both need to be removed
    if len(query["tokens"].shape) == 2:  # (embedding lookup matrix)
        # shape: (batch, query_max)
        query_pad_oov_mask = (query["tokens"] > 1).float()
        # shape: (batch, doc_max)
        document_pad_oov_mask = (document["tokens"] > 1).float()
    else:  # == 3 (elmo characters per word)
        # shape: (batch, query_max)
        query_pad_oov_mask = (torch.sum(query["tokens"], 2) > 0).float()
        # shape: (batch, doc_max)
        document_pad_oov_mask = (torch.sum(document["tokens"], 2) > 0).float()

    # shape: (batch, query_max, emb_dim)
    query_embeddings = self.word_embeddings(query) * query_pad_oov_mask.unsqueeze(-1)
    # shape: (batch, document_max, emb_dim)
    document_embeddings = self.word_embeddings(document) * document_pad_oov_mask.unsqueeze(-1)

    #
    # similarity matrix
    # -------------------------------------------------------

    # create sim matrix
    cosine_matrix = self.cosine_module.forward(query_embeddings, document_embeddings).cpu()

    #
    # histogram & classifier
    # ----------------------------------------------

    histogram_tensor = torch.empty((cosine_matrix.shape[0], cosine_matrix.shape[1], self.bin_count))
    for b in range(cosine_matrix.shape[0]):
        for q in range(cosine_matrix.shape[1]):
            histogram_tensor[b, q] = torch.histc(cosine_matrix[b, q], bins=self.bin_count, min=-1, max=1)
    histogram_tensor = histogram_tensor.to(device=query_embeddings.device)

    # log1p is super important here - just the opposite of knrm
    classified_matches_per_query = self.matching_classifier(torch.log1p(histogram_tensor))

    #
    # query gate
    # ----------------------------------------------

    query_gates_raw = self.query_gate(query_embeddings)
    query_gates = self.query_softmax(query_gates_raw.squeeze(-1), query_pad_oov_mask).unsqueeze(-1)

    #
    # combine it all
    # ----------------------------------------------
    scores = torch.sum(classified_matches_per_query * query_gates, dim=1)
    return scores
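The nested loop over batch and query positions is needed because torch.histc has no batch or dim argument: it always reduces the whole input to a single histogram. A stand-alone sketch of the per-row binning it performs on the cosine matrix, with made-up similarities and bin_count=5:

import torch

bin_count = 5
cosine_row = torch.tensor([-0.9, -0.1, 0.05, 0.3, 0.95, 1.0])  # one query term's similarities

hist = torch.histc(cosine_row, bins=bin_count, min=-1, max=1)
print(hist)  # tensor([1., 0., 2., 1., 2.])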