Python source examples: torch.argmin()
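Before the examples, a minimal sketch of the call itself: torch.argmin(input, dim=None, keepdim=False) returns the indices of the minimum values; with no dim it indexes into the flattened tensor.

import torch

x = torch.tensor([[4.0, 1.0, 3.0],
                  [2.0, 5.0, 0.5]])
print(torch.argmin(x))                        # tensor(5): index into the flattened tensor
print(torch.argmin(x, dim=1))                 # tensor([1, 2]): per-row index of the minimum
print(torch.argmin(x, dim=0, keepdim=True))   # tensor([[1, 0, 1]]): reduced dim kept as size 1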
Example 1
def kmeans(input, n_clusters=16, tol=1e-6):
    """
    TODO: check correctness
    """
    # indices used for tensor indexing must be integer (long), not float;
    # replace=False samples distinct initial centroids
    indices = torch.from_numpy(
        np.random.choice(input.size(-1), n_clusters, replace=False)).long()
    values = input[:, :, indices]
    while True:
        dist = func.pairwise_distance(
            input.unsqueeze(2).expand(-1, -1, values.size(2), input.size(2)).reshape(
                input.size(0), input.size(1), input.size(2) * values.size(2)),
            values.unsqueeze(3).expand(-1, -1, values.size(2), input.size(2)).reshape(
                input.size(0), input.size(1), input.size(2) * values.size(2))
        )
        choice_cluster = torch.argmin(dist, dim=1)
        old_values = values
        values = input[choice_cluster.nonzero()]  # Tensor.nonzero(), not nonzeros()
        shift = (old_values - values).norm(dim=1)
        if shift.max() ** 2 < tol:
            break
    return values
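The example above is flagged "TODO: check correctness" by its author. For reference, a minimal self-contained assignment/update step built on torch.argmin (a sketch, not the author's code; assumes points of shape (n, d) and centroids of shape (k, d)):

import torch

def kmeans_step(points, centroids):
    # points: (n, d), centroids: (k, d)
    dists = torch.cdist(points, centroids)   # (n, k) pairwise Euclidean distances
    assign = torch.argmin(dists, dim=1)      # (n,) index of the nearest centroid
    new_centroids = torch.stack([
        points[assign == k].mean(dim=0) if (assign == k).any() else centroids[k]
        for k in range(centroids.size(0))    # keep empty clusters unchanged
    ])
    return assign, new_centroids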
Example 2
def forward(self, grammian):
    """Planar-case solver, used when all Vi lie on the same plane.
    Args:
        grammian: Gramian matrix G[i, j] = <Vi, Vj>, an n x n tensor
    Returns:
        sol: coefficients c = [c1, ..., cn] that solve the min-norm problem
    """
    vivj = grammian[self.ii_triu, self.jj_triu]
    vivi = grammian[self.ii_triu, self.ii_triu]
    vjvj = grammian[self.jj_triu, self.jj_triu]
    gamma, cost = self.line_solver_vectorized(vivi, vivj, vjvj)
    offset = torch.argmin(cost)
    i_min, j_min = self.i_triu[offset], self.j_triu[offset]
    sol = torch.zeros(self.n, device=grammian.device)
    sol[i_min], sol[j_min] = gamma[offset], 1. - gamma[offset]
    return sol
Example 3
def forward(self, XS, YS, XQ, YQ):
    '''
    @param XS (support x): support_size x ebd_dim
    @param YS (support y): support_size
    @param XQ (query x): query_size x ebd_dim
    @param YQ (query y): query_size
    @return acc
    @return None (a placeholder for loss)
    '''
    if self.args.nn_distance == 'l2':
        dist = self._compute_l2(XS, XQ)
    elif self.args.nn_distance == 'cos':
        dist = self._compute_cos(XS, XQ)
    else:
        raise ValueError("nn_distance can only be l2 or cos.")
    # 1-nearest-neighbour classification
    nn_idx = torch.argmin(dist, dim=1)
    pred = YS[nn_idx]
    acc = torch.mean((pred == YQ).float()).item()
    return acc, None
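_compute_l2 is not shown; a plausible sketch (an assumption, not the original helper) that matches the shape the argmin call expects, i.e. a (query_size, support_size) distance matrix:

def _compute_l2(XS, XQ):
    # XS: (support_size, ebd_dim), XQ: (query_size, ebd_dim)
    # returns (query_size, support_size) squared L2 distances
    diff = XQ.unsqueeze(1) - XS.unsqueeze(0)
    return (diff ** 2).sum(dim=-1)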
Example 4
def forward(self, coord_volumes_batch, volumes_batch_pred, keypoints_gt, keypoints_binary_validity):
    loss = 0.0
    n_losses = 0
    batch_size = volumes_batch_pred.shape[0]
    for batch_i in range(batch_size):
        coord_volume = coord_volumes_batch[batch_i]
        keypoints_gt_i = keypoints_gt[batch_i]
        coord_volume_unsq = coord_volume.unsqueeze(0)
        keypoints_gt_i_unsq = keypoints_gt_i.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        dists = torch.sqrt(((coord_volume_unsq - keypoints_gt_i_unsq) ** 2).sum(-1))
        dists = dists.view(dists.shape[0], -1)
        min_indexes = torch.argmin(dists, dim=-1).detach().cpu().numpy()
        min_indexes = np.stack(np.unravel_index(min_indexes, volumes_batch_pred.shape[-3:]), axis=1)
        for joint_i, index in enumerate(min_indexes):
            validity = keypoints_binary_validity[batch_i, joint_i]
            loss += validity[0] * (-torch.log(volumes_batch_pred[batch_i, joint_i, index[0], index[1], index[2]] + 1e-6))
            n_losses += 1
    return loss / n_losses
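The flat-argmin + np.unravel_index pattern above converts a flattened minimum index back into volumetric coordinates; a standalone check of the pattern:

import numpy as np
import torch

vol = torch.rand(2, 4, 4, 4)                       # (n_joints, x, y, z)
flat = torch.argmin(vol.view(2, -1), dim=1).numpy()
coords = np.stack(np.unravel_index(flat, vol.shape[-3:]), axis=1)  # (2, 3)
# coords[j] is the (x, y, z) index of joint j's minimum voxel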
Example 5
def assign(self, points, distance='euclid', greedy=False):
    # points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    print('Distances:', distances[:3])
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
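The distance computation in these assign variants uses the expansion ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p·c; a quick check that it matches torch.cdist:

import torch

points = torch.randn(8, 16)
centroids = torch.randn(4, 16)
d2 = (torch.sum(points**2, dim=1, keepdim=True) +
      torch.sum(centroids**2, dim=1, keepdim=True).t() -
      2 * torch.matmul(points, centroids.t()))        # (8, 4) squared distances
assert torch.allclose(d2, torch.cdist(points, centroids)**2, atol=1e-5)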
Example 6
def assign(self, points, distance='euclid', greedy=False):
    centroids = F.dropout(self.centroids, p=0.3)
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    print('Distances:', distances[:3])
    if not greedy:
        resp = - .5 * self.tau * distances - self.reduce_dim / 2 * math.log(2 * math.pi * self.tau) + torch.log(self.prior)
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example 7
def assign(self, points, distance='euclid', greedy=False):
    points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example 8
def assign(self, points, distance='euclid', greedy=False):
    # points = points.data  # the only diff from 16
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
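In the greedy branch, scatter_ turns the argmin indices into one-hot responsibilities; a standalone illustration:

import torch

indices = torch.tensor([2, 0, 1])             # nearest centroid per point
resp = torch.zeros(3, 4)
resp.scatter_(1, indices.unsqueeze(1), 1)     # one-hot rows along dim 1
# resp == [[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]]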
Example 9
def calculate_partitions(partitions_count, cluster_partitions, types):
    partition_distribution = torch.ones((partitions_count,
                                         len(torch.unique(types))),
                                        dtype=torch.long)
    partition_assignments = torch.zeros(cluster_partitions.shape[0],
                                        dtype=torch.long)
    for i in torch.unique(cluster_partitions):
        cluster_positions = (cluster_partitions == i).nonzero()
        cluster_types = types[cluster_positions]
        unique_types_in_cluster, type_count = torch.unique(cluster_types, return_counts=True)
        tmp_distribution = partition_distribution.clone()
        tmp_distribution[:, unique_types_in_cluster] += type_count
        relative_distribution = partition_distribution.double() / tmp_distribution.double()
        min_relative_distribution_group = torch.argmin(torch.sum(relative_distribution, dim=1))
        partition_distribution[min_relative_distribution_group,
                               unique_types_in_cluster] += type_count
        partition_assignments[cluster_positions] = min_relative_distribution_group
    write_out("Loaded data into the following partitions")
    write_out("[[ TM SP+TM SP Glob]")
    write_out(partition_distribution - torch.ones(partition_distribution.shape,
                                                  dtype=torch.long))
    return partition_assignments
Example 10
def _nnef_argminmax_reduce(input, axes, argmin=False):
    # type: (torch.Tensor, List[int], bool) -> torch.Tensor
    if len(axes) == 1:
        return _nnef_generic_reduce(input=input, axes=axes, f=torch.argmin if argmin else torch.argmax)
    else:
        axes = sorted(axes)
        consecutive_axes = list(range(axes[0], axes[0] + len(axes)))
        if axes == consecutive_axes:
            reshaped = nnef_reshape(input,
                                    shape=(list(input.shape)[:axes[0]]
                                           + [-1]
                                           + list(input.shape[axes[0] + len(axes):])))
            reduced = _nnef_generic_reduce(input=reshaped, axes=[axes[0]], f=torch.argmin if argmin else torch.argmax)
            reshaped = nnef_reshape(reduced, shape=list(dim if axis not in axes else 1
                                                        for axis, dim in enumerate(input.shape)))
            return reshaped
        else:
            raise utils.NNEFToolsException(
                "{} is only implemented for consecutive axes.".format("argmin_reduce" if argmin else "argmax_reduce"))
Example 11
def step(self, i):
    """
    There are two standard steps for each iteration: expectation (E) and
    minimization (M). The E-step (assignment) is performed with an exhaustive
    search and the M-step (centroid computation) is performed with
    the exact solution.
    Args:
        - i: step number
    Remarks:
        - The E-step heavily uses PyTorch broadcasting to speed up computations
          and reduce the memory overhead
    """
    # assignments (E-step)
    distances = self.compute_distances()  # (n_centroids x out_features)
    self.assignments = torch.argmin(distances, dim=0)  # (out_features)
    n_empty_clusters = self.resolve_empty_clusters()
    # centroids (M-step)
    for k in range(self.n_centroids):
        W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
        self.centroids[k] = W_k.mean(dim=1)  # (in_features)
    # book-keeping
    obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
    self.objective.append(obj)
    if self.verbose:
        logging.info(
            f"Iteration: {i},\t"
            f"objective: {obj:.6f},\t"
            f"resolved empty clusters: {n_empty_clusters}"
        )
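compute_distances is not shown here; a plausible sketch (an assumption, not necessarily the original implementation) consistent with the (n_centroids x out_features) shape consumed by argmin:

def compute_distances(self):
    # self.centroids: (n_centroids, in_features); self.W: (in_features, out_features)
    # returns (n_centroids, out_features): distance of each centroid to each column of W
    return torch.cdist(self.centroids, self.W.t())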
Example 12
def resolve_empty_clusters(self):
    """
    If one cluster is empty, the most populated cluster is split into
    two clusters by shifting the respective centroids. This is done
    iteratively for a fixed number of tentatives.
    """
    # empty clusters
    counts = Counter(map(lambda x: x.item(), self.assignments))
    empty_clusters = set(range(self.n_centroids)) - set(counts.keys())
    n_empty_clusters = len(empty_clusters)
    tentatives = 0
    while len(empty_clusters) > 0:
        # given an empty cluster, find most populated cluster and split it into two
        k = random.choice(list(empty_clusters))
        m = counts.most_common(1)[0][0]
        e = torch.randn_like(self.centroids[m]) * self.eps
        self.centroids[k] = self.centroids[m].clone()
        self.centroids[k] += e
        self.centroids[m] -= e
        # recompute assignments
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)
        # check for empty clusters
        counts = Counter(map(lambda x: x.item(), self.assignments))
        empty_clusters = set(range(self.n_centroids)) - set(counts.keys())
        # increment tentatives
        if tentatives == self.max_tentatives:
            logging.info(
                f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
            )
            raise EmptyClusterResolveError
        tentatives += 1
    return n_empty_clusters
Example 13
def assign(self):
    """
    Assigns each column of W to its closest centroid, thus essentially
    performing the E-step in train().
    Remarks:
        - The function must be called after train() or after loading
          centroids using self.load(), otherwise it will return empty tensors
    """
    distances = self.compute_distances()  # (n_centroids x out_features)
    self.assignments = torch.argmin(distances, dim=0)  # (out_features)
Example 14
def append(self, new_k, new_v):
    min_idx = torch.argmin(self.weight_buff).item()
    self.keys[min_idx, :] = new_k
    self.values[min_idx, :] = new_v
    self.weight_buff[min_idx] = torch.mean(self.weight_buff)
Example 15
def append(self, new_k, new_v):
    """
    :param new_k: expecting a vector of dimensionality [Num Key Chan]
    :param new_v: expecting a vector of dimensionality [Num Value Chan]
    :return:
    """
    min_idx = torch.argmin(self.weight_buff).item()
    self.keys[min_idx, :] = new_k
    self.values[min_idx, :] = new_v
    self.weight_buff[min_idx] = torch.mean(self.weight_buff)
Example 16
def forward(self, vecs):
    """Builds the Gramian matrix G[i, j] = <v_i, v_j> and iteratively
    solves the min-norm problem over the simplex.
    """
    if self.n_tasks == 1:
        return vecs[0]
    if self.n_tasks == 2:
        v1v1 = torch.dot(vecs[0], vecs[0])
        v1v2 = torch.dot(vecs[0], vecs[1])
        v2v2 = torch.dot(vecs[1], vecs[1])
        gamma = self.line_solver(v1v1, v1v2, v2v2)
        return gamma * vecs[0] + (1. - gamma) * vecs[1]
    self.sol.fill_(1. / self.n)
    self.new_sol.copy_(self.sol)
    torch.mm(vecs, vecs.t(), out=self.grammian)
    for iter_count in range(self.MAX_ITER):
        gram_dot_sol = torch.mv(self.grammian, self.sol)
        t_iter = torch.argmin(gram_dot_sol)
        v1v1 = torch.dot(self.sol, gram_dot_sol)
        v1v2 = torch.dot(self.sol, self.grammian[:, t_iter])
        v2v2 = self.grammian[t_iter, t_iter]
        gamma = self.line_solver(v1v1, v1v2, v2v2)
        self.new_sol *= gamma
        self.new_sol[t_iter] += 1. - gamma
        change = self.new_sol - self.sol
        if torch.sum(torch.abs(change)) < self.STOP_CRIT:
            return self.new_sol
        self.sol.copy_(self.new_sol)
    return self.sol
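line_solver is assumed to return the closed-form minimizer of ||gamma*v1 + (1-gamma)*v2||^2; a sketch of the usual formula (an assumption, not necessarily the author's version):

def line_solver(v1v1, v1v2, v2v2):
    # argmin over gamma in [0, 1] of
    # gamma^2*v1v1 + 2*gamma*(1-gamma)*v1v2 + (1-gamma)^2*v2v2
    denom = v1v1 + v2v2 - 2 * v1v2
    if denom <= 1e-12:               # v1 == v2: every gamma is optimal
        return 0.5
    gamma = (v2v2 - v1v2) / denom
    return min(1.0, max(0.0, float(gamma)))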
Example 17
def _proj_val(x, set):
    """
    Compute the projection of x onto the given set.
    :param x: Input pytorch tensor.
    :param set: Input pytorch vector used to perform the projection.
    """
    x = x.repeat((set.size()[0],) + (1,) * len(x.size()))
    x = x.permute(*(tuple(range(len(x.size())))[1:] + (0,)))
    x = torch.abs(x - set)
    x = torch.argmin(x, dim=len(x.size()) - 1, keepdim=False)
    return set[x]
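Usage sketch, assuming the function above is in scope: snapping each entry of a tensor to the nearest element of a discrete set (e.g. for quantization):

import torch

levels = torch.tensor([-1.0, 0.0, 1.0])
x = torch.tensor([[0.2, -0.7],
                  [0.9,  0.4]])
print(_proj_val(x, levels))  # tensor([[ 0., -1.], [ 1.,  0.]])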
Example 18
@contextmanager  # from contextlib; assumed here, since the yield below makes this a context manager
def least_used_cuda_device() -> Generator:
    """Context manager for automatically selecting the CUDA device
    with the least allocated memory"""
    mem_allocs = get_cuda_max_memory_allocations()
    least_used_device = torch.argmin(mem_allocs).item()
    with torch.cuda.device(least_used_device):
        yield
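get_cuda_max_memory_allocations is assumed to return one entry per visible device; a plausible sketch (an assumption, not the original helper) using torch.cuda's built-in counters:

def get_cuda_max_memory_allocations() -> torch.Tensor:
    # peak memory allocated so far on each visible CUDA device
    return torch.tensor([torch.cuda.max_memory_allocated(i)
                         for i in range(torch.cuda.device_count())])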
Example 19
def subgraph_filter(x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, args):
    D = sqdist(x_atom_pos[:, :, :3], x_atom_pos[:, :, :3])
    x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle = \
        x_atom.clone().detach(), x_atom_pos.clone().detach(), x_bond.clone().detach(), \
        x_bond_dist.clone().detach(), x_triplet.clone().detach(), x_triplet_angle.clone().detach()
    bsz = x_atom.shape[0]
    bonds_mask = torch.ones(bsz, x_bond.shape[1], 1).to(x_atom.device)
    for mol_id in range(bsz):
        if np.random.uniform(0, 1) > args.cutout:
            continue
        assert not args.use_quad, "Quads are NOT cut out yet"
        atom_dists = D[mol_id]
        atoms = x_atom[mol_id, :, 0]
        n_valid_atoms = (atoms > 0).sum().item()
        if n_valid_atoms < 10:
            continue
        idx_to_drop = np.random.randint(n_valid_atoms - 1)
        dist_row = atom_dists[idx_to_drop]
        neighbor_to_drop = torch.argmin((dist_row[dist_row > 0])[:n_valid_atoms - 1]).item()
        if neighbor_to_drop >= idx_to_drop:
            neighbor_to_drop += 1
        x_atom[mol_id, idx_to_drop] = 0
        x_atom[mol_id, neighbor_to_drop] = 0
        x_atom_pos[mol_id, idx_to_drop] = 0
        x_atom_pos[mol_id, neighbor_to_drop] = 0
        bond_pos_to_drop = (x_bond[mol_id, :, 3] == idx_to_drop) | (x_bond[mol_id, :, 3] == neighbor_to_drop) \
            | (x_bond[mol_id, :, 4] == idx_to_drop) | (x_bond[mol_id, :, 4] == neighbor_to_drop)
        trip_pos_to_drop = (x_triplet[mol_id, :, 2] == idx_to_drop) | (x_triplet[mol_id, :, 2] == neighbor_to_drop) \
            | (x_triplet[mol_id, :, 3] == idx_to_drop) | (x_triplet[mol_id, :, 3] == neighbor_to_drop) \
            | (x_triplet[mol_id, :, 4] == idx_to_drop) | (x_triplet[mol_id, :, 4] == neighbor_to_drop)
        x_bond[mol_id, bond_pos_to_drop] = 0
        x_bond_dist[mol_id, bond_pos_to_drop] = 0
        bonds_mask[mol_id, bond_pos_to_drop] = 0
        x_triplet[mol_id, trip_pos_to_drop] = 0
        x_triplet_angle[mol_id, trip_pos_to_drop] = 0
    return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, bonds_mask
Example 20
def unravel_index(tensor, cols):
    """
    args:
        tensor : 2D tensor, [nb, rows*cols]
        cols : int
    return 2D tensor nb * [rowIndex, colIndex]
    """
    index = torch.argmin(tensor, dim=1).view(-1, 1)
    rIndex = torch.div(index, cols, rounding_mode='floor')  # integer row index; plain / would return floats
    cIndex = index % cols
    minRC = torch.cat([rIndex, cIndex], dim=1)
    # print("minRC", minRC.shape, minRC)
    return minRC
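A quick usage check with two flattened 2x2 grids:

import torch

t = torch.tensor([[3., 0., 2., 9.],
                  [7., 8., 1., 6.]])   # two samples, each a flattened 2x2 grid
print(unravel_index(t, cols=2))        # tensor([[0, 1], [1, 0]])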
Example 21
def _compute_basic_stats(data):
    # compute on non-zero data:
    data = data[data > 0]
    out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")]
    if len(data) > 1:
        out += [
            ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
            ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
            ("mean", torch.mean(data).item()),
            ("std", torch.std(data).item()),
        ]
    return OrderedDict(out)
Example 22
def nnef_argmax_reduce(input, axes):
    # type: (torch.Tensor, List[int]) -> torch.Tensor
    return _nnef_argminmax_reduce(input, axes, argmin=False)
Example 23
def nnef_argmin_reduce(input, axes):
    # type: (torch.Tensor, List[int]) -> torch.Tensor
    return _nnef_argminmax_reduce(input, axes, argmin=True)
Example 24
def forward(self, segmentation, prob, gt_instance, gt_plane_num):
    """
    greedy matching
    match segmentation with ground truth instance
    :param segmentation: tensor with size (N, K)
    :param prob: tensor with size (N, 1)
    :param gt_instance: tensor with size (21, h, w)
    :param gt_plane_num: int
    :return: a (K, 1) long tensor indicating the closest ground truth instance id, starting from 0
    """
    n, k = segmentation.size()
    _, h, w = gt_instance.size()
    assert (prob.size(0) == n and h * w == n)
    # ignore the non-planar region
    gt_instance = gt_instance[:gt_plane_num, :, :].view(1, -1, h * w)  # (1, gt_plane_num, h*w)
    segmentation = segmentation.t().view(k, 1, h * w)  # (k, 1, h*w)
    # calculate the instance-wise cross-entropy matrix (k, gt_plane_num)
    gt_instance = gt_instance.type(torch.float32)
    ce_loss = - (gt_instance * torch.log(segmentation + 1e-6) +
                 (1 - gt_instance) * torch.log(1 - segmentation + 1e-6))  # (k, gt_plane_num, h*w)
    ce_loss = torch.mean(ce_loss, dim=2)  # (k, gt_plane_num)
    matching = torch.argmin(ce_loss, dim=1, keepdim=True)
    return matching
Example 25
def choose_best_match_batch(Rrois, gt_rois):
    """
    Choose the best-matching representation of gt_rois for the given Rrois.
    :param Rrois: (x_ctr, y_ctr, w, h, angle)
        shape: (n, 5)
    :param gt_rois: (x_ctr, y_ctr, w, h, angle)
        shape: (n, 5)
    :return: gt_rois_new: gt_rois with the new representation
        shape: (n, 5)
    """
    # TODO: check the dimensions
    Rroi_angles = Rrois[:, 4].unsqueeze(1)
    gt_xs, gt_ys, gt_ws, gt_hs, gt_angles = copy.deepcopy(gt_rois[:, 0]), copy.deepcopy(gt_rois[:, 1]), \
        copy.deepcopy(gt_rois[:, 2]), copy.deepcopy(gt_rois[:, 3]), \
        copy.deepcopy(gt_rois[:, 4])
    gt_angle_extent = torch.cat((gt_angles[:, np.newaxis], (gt_angles + np.pi / 2.)[:, np.newaxis],
                                 (gt_angles + np.pi)[:, np.newaxis], (gt_angles + np.pi * 3 / 2.)[:, np.newaxis]), 1)
    dist = (Rroi_angles - gt_angle_extent) % (2 * np.pi)
    dist = torch.min(dist, np.pi * 2 - dist)
    min_index = torch.argmin(dist, 1)
    gt_rois_extent0 = copy.deepcopy(gt_rois)
    gt_rois_extent1 = torch.cat((gt_xs.unsqueeze(1), gt_ys.unsqueeze(1),
                                 gt_hs.unsqueeze(1), gt_ws.unsqueeze(1), gt_angles.unsqueeze(1) + np.pi / 2.), 1)
    gt_rois_extent2 = torch.cat((gt_xs.unsqueeze(1), gt_ys.unsqueeze(1),
                                 gt_ws.unsqueeze(1), gt_hs.unsqueeze(1), gt_angles.unsqueeze(1) + np.pi), 1)
    gt_rois_extent3 = torch.cat((gt_xs.unsqueeze(1), gt_ys.unsqueeze(1),
                                 gt_hs.unsqueeze(1), gt_ws.unsqueeze(1), gt_angles.unsqueeze(1) + np.pi * 3 / 2.), 1)
    gt_rois_extent = torch.cat((gt_rois_extent0.unsqueeze(1),
                                gt_rois_extent1.unsqueeze(1),
                                gt_rois_extent2.unsqueeze(1),
                                gt_rois_extent3.unsqueeze(1)), 1)
    gt_rois_new = torch.zeros_like(gt_rois)
    # TODO: add pool.map here
    for curiter, index in enumerate(min_index):
        gt_rois_new[curiter, :] = gt_rois_extent[curiter, index, :]
    gt_rois_new[:, 4] = gt_rois_new[:, 4] % (2 * np.pi)
    return gt_rois_new