Python源码示例:torch.log_softmax()
示例1
def forward(self, x):
    """Return log-probabilities from a linear term plus pairwise interactions.

    For every feature pair ``(i, j)`` with ``i < j``, the latent block of
    feature ``i`` selected by the field of ``j`` is multiplied element-wise
    with the latent block of feature ``j`` selected by the field of ``i``.
    The product is summed over its first axis and weighted by
    ``x[:, i] * x[:, j]``. The summed score is passed through a
    log-softmax over dim 1.
    """
    # Linear (first-order) contribution.
    first_order = self.linear(x)

    # Pairwise (second-order) contribution, accumulated pair by pair.
    pairwise = 0.0
    n_features = self.fea_num
    for left in range(n_features):
        for right in range(left + 1, n_features):
            latent_left = self.v[left, self.field_map_dict[right], :, :]
            latent_right = self.v[right, self.field_map_dict[left], :, :]
            pair_value = (x[:, left] * x[:, right]).unsqueeze(1)
            pair_weight = (latent_left * latent_right).sum(dim=0).unsqueeze(0)
            pairwise += torch.mm(pair_value, pair_weight)

    return torch.log_softmax(first_order + pairwise, dim=1)
示例2
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
    """Run one training pass over ``data_loader`` and return the average loss.

    Each batch yields ``(image, target, input_len, target_len)``. The model
    output is turned into log-probabilities over dim 2, its first two axes
    are swapped, and the result is handed to ``criterion`` together with the
    targets and both length arguments. Each batch loss is weighted by the
    batch size; the total is divided by ``len(data_loader.dataset)``.
    """
    running_loss = 0.0
    for image, target, input_len, target_len in tqdm(data_loader):
        image = image.to(device)
        log_probs = torch.log_softmax(model(image.to(torch.float32)), dim=2)  # [B,N,C]
        log_probs = log_probs.permute([1, 0, 2])  # [N,B,C]
        loss = criterion(log_probs[:], target, input_len, target_len)

        # Gradient update.
        model.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate this batch's loss, weighted by the batch size.
        running_loss += loss.item() * image.size(0)
        if np.isnan(loss.item()):
            # Dump the offending batch to help track down the NaN.
            print(target, input_len, target_len)

    epoch_loss = running_loss / len(data_loader.dataset)
    # Log the epoch summary.
    summary = 'Epoch: {}/{} loss: {:03f}'.format(epoch + 1, args.epochs, epoch_loss)
    print(summary)
    return epoch_loss
示例3
def forward(self, task_id, x, y, seq_len):
    """Encode the batch and compute the loss and predictions for one task.

    Args:
        task_id: Indexable; only ``task_id[0]`` is used, to select the output
            layer (and the CRF, when one is configured).
        x: Input passed to both ``self.embedding`` and ``self.char``.
        y: Gold labels passed to the loss.
        seq_len: Sequence lengths; passed to the LSTM and used to build the
            sequence mask.

    Returns:
        dict: ``{"loss": ..., "pred": ...}``.
    """
    words_emb = self.embedding(x)
    char_emb = self.char(x)
    x = torch.cat([words_emb, char_emb], dim=-1)
    x, _ = self.lstm(x, seq_len)
    # Dropout returns its result. The original discarded the return value,
    # so dropout was never applied unless the module was configured in-place.
    x = self.dropout(x)

    task = task_id[0]
    logit = self.out[task](x)
    seq_mask = seq_len_to_mask(seq_len, x.size(1))
    if self.crf is not None:
        # The CRF consumes log-probabilities both for the loss and for decoding.
        logit = torch.log_softmax(logit, dim=-1)
        loss = self.crf[task](logit, y, seq_mask).mean()
        pred = self.crf[task].viterbi_decode(logit, seq_mask)[0]
    else:
        loss = ce_loss(logit, y, seq_mask)
        pred = torch.argmax(logit, dim=2)
    return {"loss": loss, "pred": pred}
示例4
def distillation(logits_student, logits_teacher, ylens, temperature=5.0):
    """Compute cross entropy loss for knowledge distillation of sequence-to-sequence models.

    Args:
        logits_student (FloatTensor): `[B, T, vocab]`
        logits_teacher (FloatTensor): `[B, T, vocab]`
        ylens (IntTensor): `[B]`
        temperature (float): divisor applied to the teacher logits only
    Returns:
        loss_mean (FloatTensor): scalar tensor; the loss summed over the first
            `ylens[b]` steps of every sequence, divided by `ylens.sum()`

    """
    bs, _, _ = logits_student.size()
    log_probs_student = torch.log_softmax(logits_student, dim=-1)
    # The teacher only provides targets, so keep it out of the autograd graph.
    probs_teacher = torch.softmax(logits_teacher / temperature, dim=-1).detach()
    loss = -torch.mul(probs_teacher, log_probs_student)
    # Use the builtin sum rather than np.sum: NumPy cannot convert tensors
    # that require grad (or live on an accelerator), whereas the builtin just
    # chains tensor additions and keeps the result in the graph.
    loss_sum = sum(loss[b, :ylens[b], :].sum() for b in range(bs))
    loss_mean = loss_sum / ylens.sum()
    return loss_mean
示例5
def kldiv_lsm_ctc(logits, ylens):
    """Compute KL divergence loss for label smoothing of CTC and Transducer models.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ylens (IntTensor): `[B]`
    Returns:
        loss_mean (FloatTensor): scalar tensor; the loss summed over the first
            `ylens[b]` steps of every sequence, divided by `ylens.sum()`

    """
    bs, _, vocab = logits.size()
    # Reference log-probability: uniform over `vocab - 1` classes, applied to
    # every vocabulary entry (broadcast as a scalar).
    log_uniform = math.log(1 / (vocab - 1))
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    loss = torch.mul(probs, log_probs - log_uniform)
    # Use the builtin sum rather than np.sum: NumPy cannot convert tensors
    # that require grad (or live on an accelerator), whereas the builtin just
    # chains tensor additions and keeps the result in the graph.
    loss_sum = sum(loss[b, :ylens[b], :].sum() for b in range(bs))
    # NOTE: the result may be negative, because the reference is spread over
    # all `vocab` entries but normalised for `vocab - 1`.
    loss_mean = loss_sum / ylens.sum()
    return loss_mean
示例6
def focal_loss(logits, ys, ylens, alpha, gamma):
    """Compute focal loss.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ys (LongTensor): Indices of labels. `[B, L]` (only its batch size is read)
        ylens (IntTensor): `[B]`
        alpha (float): global scale of the loss
        gamma (float): exponent applied to `1 - p`
    Returns:
        loss_mean (FloatTensor): scalar tensor; the loss summed over the first
            `ylens[b]` steps and all vocabulary entries of every sequence,
            divided by `ylens.sum()`

    """
    bs = ys.size(0)
    log_probs = torch.log_softmax(logits, dim=-1)
    probs_inv = -torch.softmax(logits, dim=-1) + 1
    loss = -alpha * torch.mul(torch.pow(probs_inv, gamma), log_probs)
    # Use the builtin sum rather than np.sum: NumPy cannot convert tensors
    # that require grad (or live on an accelerator), whereas the builtin just
    # chains tensor additions and keeps the result in the graph.
    loss_sum = sum(loss[b, :ylens[b], :].sum() for b in range(bs))
    loss_mean = loss_sum / ylens.sum()
    return loss_mean
示例7
def greedy(self, eouts, elens):
    """Greedy decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (np.ndarray): `[B]`
    Returns:
        hyps (np.ndarray): Best path hypothesis. A regular 2-D array when all
            hypotheses have the same length, otherwise a 1-D object array
            holding one integer array per utterance.

    """
    log_probs = torch.log_softmax(self.output(eouts), dim=-1)
    best_paths = log_probs.argmax(-1)  # `[B, T]`

    hyps = []
    for b in range(eouts.size(0)):
        # One slice + tolist() instead of a per-frame .item() call.
        indices = best_paths[b, :elens[b]].tolist()
        # Step 1. Collapse repeated labels
        collapsed_indices = [label for label, _ in groupby(indices)]
        # Step 2. Remove all blank labels
        best_hyp = [label for label in collapsed_indices if label != self.blank]
        hyps.append(np.array(best_hyp))

    if len({len(hyp) for hyp in hyps}) <= 1:
        return np.array(hyps)

    # Hypotheses usually differ in length. Recent NumPy refuses to build an
    # array from ragged sequences implicitly, so fill an object array by hand.
    ragged = np.empty(len(hyps), dtype=object)
    for b, hyp in enumerate(hyps):
        ragged[b] = hyp
    return ragged
示例8
def test_log_softmax():
    """Check ``scatter_log_softmax`` against per-group ``torch.log_softmax``."""
    neg_inf = float('-inf')
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, neg_inf])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_log_softmax(src, index)

    # Reference values: a plain log-softmax over the members of each group.
    group0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
    group1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    group2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    group4 = torch.log_softmax(torch.tensor([-1, neg_inf]), dim=-1)

    # Re-interleave the per-group results in the order given by `index`.
    reference = [
        group0[0], group1[0], group0[1], group1[1],
        group1[2], group2[0], group4[0], group4[1],
    ]
    expected = torch.stack(reference, dim=0)
    assert torch.allclose(out, expected)

    # The backward pass must run as well.
    out.backward(torch.randn_like(out))
示例9
def forward(self, x, target):
    """Compute loss between x and target.

    :param torch.Tensor x: prediction (batch, seqlen, class)
    :param torch.Tensor target:
        target signal masked with self.padding_idx (batch, seqlen)
    :return: scalar float value
    :rtype torch.Tensor
    """
    assert x.size(2) == self.size
    batch_size = x.size(0)
    # reshape() rather than view(): view() raises on non-contiguous inputs
    # (e.g. transposed or sliced tensors); reshape() copies only when needed.
    x = x.reshape(-1, self.size)
    target = target.reshape(-1)
    with torch.no_grad():
        # Smoothed target: `smoothing` spread over the other classes,
        # `confidence` on the gold class.
        true_dist = x.clone()
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (batch * seqlen,)
        total = len(target) - ignore.sum().item()
        # Padding positions get a valid dummy index; they are masked below.
        target = target.masked_fill(ignore, 0)
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
    # Normalise by the number of non-padding tokens or by the batch size.
    denom = total if self.normalize_length else batch_size
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
示例10
def forward(self, x, y, get_scores=False):
    """
    Compute the loss, and optionally the scores.

    When ``self.asm`` is False the projection output is scored with plain
    cross entropy (``label_smoothing == 0``) or with a label-smoothed
    negative log-likelihood averaged over ``x.shape[0]``. Otherwise the loss
    comes from ``self.proj(x, y)`` and scores are only computed when
    ``get_scores`` is set.

    Returns a ``(scores, loss)`` tuple.
    """
    assert (y == self.pad_index).sum().item() == 0

    if self.asm is False:
        scores = self.proj(x).view(-1, self.n_words)
        if self.label_smoothing == 0.0:
            # 'mean' is the supported spelling of the deprecated
            # 'elementwise_mean' reduction.
            loss = F.cross_entropy(scores, y, reduction='mean')
        else:
            lprobs = torch.log_softmax(scores, dim=1)
            # NLL of the gold class, and the sum of the negative
            # log-probabilities over all classes (the smoothing term).
            nll_loss = -lprobs.gather(dim=-1, index=y.unsqueeze(1))
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
            eps_i = self.label_smoothing / lprobs.size(-1)
            loss = (1. - self.label_smoothing) * nll_loss + eps_i * smooth_loss
            loss = loss / x.shape[0]
    else:
        _, loss = self.proj(x, y)
        scores = self.proj.log_prob(x) if get_scores else None

    return scores, loss
示例11
def init_step(self, beam, expected_len_pen):
    """Advance ``beam`` with a fixed first-step distribution and verify it.

    The same log-probability row is fed to every (batch, beam) slot. The
    expected scores and ids are the top ``BEAM_SZ`` entries, per batch, of
    that row added to the log-probs the beam held before ``advance``.

    Returns the expected beam scores.
    """
    # init_preds: [4, 3, 5, 6, 7] - no EOS's
    logits = torch.tensor([[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float)
    n_rows = self.BATCH_SZ * self.BEAM_SZ
    init_scores = deepcopy(torch.log_softmax(logits, dim=1).repeat(n_rows, 1))

    # Expected values are derived from the beam state read before advance().
    prior = beam.topk_log_probs.view(-1).unsqueeze(1)
    new_scores = init_scores + prior
    flat_scores = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)
    expected_beam_scores, expected_preds_0 = flat_scores.topk(self.BEAM_SZ, dim=-1)

    beam.advance(deepcopy(init_scores), self.random_attn())

    self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
    self.assertTrue(beam.topk_ids.equal(expected_preds_0))
    self.assertFalse(beam.is_finished.any())
    self.assertFalse(beam.done)
    return expected_beam_scores
示例12
def forward(self, x, target):
    """Compute loss between x and target.

    :param torch.Tensor x: prediction (batch, seqlen, class)
    :param torch.Tensor target: target signal masked with self.padding_idx (batch, seqlen)
    :return: scalar float value
    :rtype torch.Tensor
    """
    assert x.size(2) == self.size
    batch_size = x.size(0)
    # reshape() rather than view(): view() raises on non-contiguous inputs
    # (e.g. transposed or sliced tensors); reshape() copies only when needed.
    x = x.reshape(-1, self.size)
    target = target.reshape(-1)
    with torch.no_grad():
        # Smoothed target: `smoothing` spread over the other classes,
        # `confidence` on the gold class.
        true_dist = x.clone()
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (batch * seqlen,)
        total = len(target) - ignore.sum().item()
        # Padding positions get a valid dummy index; they are masked below.
        target = target.masked_fill(ignore, 0)
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
    # Normalise by the number of non-padding tokens or by the batch size.
    denom = total if self.normalize_length else batch_size
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
示例13
def forward(ctx, logits, label, lb_smooth, lb_ignore):
    """Label-smoothed cross entropy, forward pass.

    Builds a smoothed one-hot target along dim 1 of ``logits``
    (``1 - lb_smooth`` on the labelled class, ``lb_smooth / num_classes``
    elsewhere), zeroes it at every position whose label equals ``lb_ignore``,
    and returns ``-(log_softmax(logits) * target).sum(dim=1)``.

    ``ctx.variables`` is set to ``(coeff, mask, logits, lb_one_hot)``.
    The caller's ``label`` tensor is not modified.
    """
    # Prepare the label.
    num_classes = logits.size(1)
    lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
    label = label.clone().detach()
    ignore = label == lb_ignore
    # Ignored positions get a valid dummy class so scatter_ can index them.
    label[ignore] = 0
    lb_one_hot = torch.empty_like(logits).fill_(
        lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

    # Build an index that selects every class channel at each ignored
    # position: (batch index, all classes, remaining coordinates...).
    ignore = ignore.nonzero()
    _, M = ignore.size()
    a, *b = ignore.chunk(M, dim=1)
    # A tuple, not a list: indexing with a list of tensors is deprecated
    # non-tuple multidimensional indexing.
    mask = (a, torch.arange(logits.size(1)), *b)
    lb_one_hot[mask] = 0

    coeff = (num_classes - 1) * lb_neg + lb_pos
    ctx.variables = coeff, mask, logits, lb_one_hot

    loss = torch.log_softmax(logits, dim=1).neg_().mul_(lb_one_hot).sum(dim=1)
    return loss
示例14
def init_step(self, beam, expected_len_pen):
    """Advance ``beam`` with a fixed first-step distribution and verify it.

    The same log-probability row is fed to every (batch, beam) slot. The
    expected scores and ids are the top ``BEAM_SZ`` entries, per batch, of
    that row added to the log-probs the beam held before ``advance``.

    Returns the expected beam scores.
    """
    # init_preds: [4, 3, 5, 6, 7] - no EOS's
    logits = torch.tensor([[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float)
    n_rows = self.BATCH_SZ * self.BEAM_SZ
    init_scores = deepcopy(torch.log_softmax(logits, dim=1).repeat(n_rows, 1))

    # Expected values are derived from the beam state read before advance().
    prior = beam.topk_log_probs.view(-1).unsqueeze(1)
    new_scores = init_scores + prior
    flat_scores = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)
    expected_beam_scores, expected_preds_0 = flat_scores.topk(self.BEAM_SZ, dim=-1)

    beam.advance(deepcopy(init_scores), self.random_attn())

    self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
    self.assertTrue(beam.topk_ids.equal(expected_preds_0))
    self.assertFalse(beam.is_finished.any())
    self.assertFalse(beam.done)
    return expected_beam_scores
示例15
def init_step(self, beam, expected_len_pen):
    """Advance ``beam`` with a fixed first-step distribution and verify it.

    The same log-probability row is fed to every (batch, beam) slot. The
    expected scores and ids are the top ``BEAM_SZ`` entries, per batch, of
    that row added to the log-probs the beam held before ``advance``.

    Returns the expected beam scores.
    """
    # init_preds: [4, 3, 5, 6, 7] - no EOS's
    logits = torch.tensor([[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float)
    n_rows = self.BATCH_SZ * self.BEAM_SZ
    init_scores = deepcopy(torch.log_softmax(logits, dim=1).repeat(n_rows, 1))

    # Expected values are derived from the beam state read before advance().
    prior = beam.topk_log_probs.view(-1).unsqueeze(1)
    new_scores = init_scores + prior
    flat_scores = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)
    expected_beam_scores, expected_preds_0 = flat_scores.topk(self.BEAM_SZ, dim=-1)

    beam.advance(deepcopy(init_scores), self.random_attn())

    self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
    self.assertTrue(beam.topk_ids.equal(expected_preds_0))
    self.assertFalse(beam.is_finished.any())
    self.assertFalse(beam.done)
    return expected_beam_scores
示例16
def forward(self, x, target):
    """Compute loss between x and target.

    :param torch.Tensor x: prediction (batch, seqlen, class)
    :param torch.Tensor target: target signal masked with self.padding_idx (batch, seqlen)
    :return: scalar float value
    :rtype torch.Tensor
    """
    assert x.size(2) == self.size
    batch_size = x.size(0)
    # reshape() for both tensors: view() raises on non-contiguous inputs
    # (e.g. transposed or sliced tensors); reshape() copies only when needed.
    x = x.reshape(-1, self.size)
    target = target.reshape(-1)
    with torch.no_grad():
        # Smoothed target: `smoothing` spread over the other classes,
        # `confidence` on the gold class.
        true_dist = x.clone()
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (batch * seqlen,)
        total = len(target) - ignore.sum().item()
        # Padding positions get a valid dummy index; they are masked below.
        target = target.masked_fill(ignore, 0)
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
    # Normalise by the number of non-padding tokens or by the batch size.
    denom = total if self.normalize_length else batch_size
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
示例17
def discriminate(self, z, edge_index):
    """Given node embeddings :obj:`z`, classifies the link relation
    between node pairs :obj:`edge_index` to be either positive,
    negative or non-existent.

    Args:
        z (Tensor): The node embeddings.
        edge_index (LongTensor): The edge indices.

    Returns the log-softmax (over dim 1) of ``self.lin`` applied to the
    concatenated embeddings of both endpoints.
    """
    src, dst = edge_index[0], edge_index[1]
    pair_features = torch.cat([z[src], z[dst]], dim=1)
    scores = self.lin(pair_features)
    return torch.log_softmax(scores, dim=1)
示例18
def forward(self, x, edge_index):
    """Return class log-probabilities (log-softmax over dim 1).

    The input is projected by ``self.lin1``. Each layer in ``self.convs``
    then receives the stack of all representations produced so far and
    appends its own output to that stack. The newest representation is
    passed through ``self.lin2``. Dropout (p=0.5) is applied after the first
    projection and before the last one.
    """
    hidden = F.dropout(F.relu(self.lin1(x)), p=0.5, training=self.training)

    # `history` stacks the per-layer outputs along dim 1.
    history = hidden.view(-1, 1, self.hidden_channels)
    for conv in self.convs:
        step = F.relu(conv(history, edge_index))
        step = step.view(-1, 1, self.hidden_channels)
        history = torch.cat([history, step], dim=1)

    latest = F.dropout(history[:, -1], p=0.5, training=self.training)
    return torch.log_softmax(self.lin2(latest), dim=1)
示例19
def forward(self, features):
    """Apply ``self.fc`` and return the log-softmax over the last axis."""
    projected = self.fc(features)
    log_probs = torch.log_softmax(projected, dim=-1)
    return log_probs
示例20
def forward(self, x):
    """Apply ``self.proj`` and return the log-softmax over the last axis."""
    projected = self.proj(x)
    return th.log_softmax(projected, dim=-1)
示例21
def train_step(protonet, datax, datay, Ns, Nc, Nq):
    """Run one optimisation step and return ``(loss, accuracy)``.

    ``protonet`` is called with the data, the three size arguments and the
    unique values of ``datay``; it returns query scores and query labels.
    The loss is the NLL of the log-softmax scores. The step uses the
    ``optimizer`` found in the enclosing scope.
    """
    optimizer.zero_grad()

    classes = np.unique(datay)
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, classes)
    log_probs = torch.log_softmax(Qx, dim=-1)
    loss = F.nll_loss(log_probs, Qy)

    loss.backward()
    optimizer.step()

    # Fraction of queries whose highest-scoring class matches the label.
    hits = (log_probs.argmax(1) == Qy).float()
    return loss, hits.mean()
示例22
def test_step(protonet, datax, datay, Ns, Nc, Nq):
    """Evaluate one episode and return ``(loss, accuracy)``.

    ``protonet`` is called with the data, the three size arguments and the
    unique values of ``datay``; it returns query scores and query labels.
    The loss is the NLL of the log-softmax scores.
    """
    classes = np.unique(datay)
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, classes)
    log_probs = torch.log_softmax(Qx, dim=-1)
    loss = F.nll_loss(log_probs, Qy)

    # Fraction of queries whose highest-scoring class matches the label.
    hits = (log_probs.argmax(1) == Qy).float()
    return loss, hits.mean()
示例23
def cross_entropy_lsm(logits, ys, lsm_prob, ignore_index, training, normalize_length=False):
    """Compute cross entropy loss for label smoothing of sequence-to-sequence models.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        ys (LongTensor): Indices of labels. `[B, L]`
        lsm_prob (float): label smoothing probability
        ignore_index (int): index for padding
        training (bool): label smoothing is only applied when this is set
        normalize_length (bool): normalize XE loss by target sequence length
    Returns:
        loss_mean (FloatTensor): scalar tensor
        ppl (float): perplexity

    """
    bs, _, vocab = logits.size()
    # reshape() rather than view(): view() raises on non-contiguous inputs.
    ys = ys.reshape(-1)
    logits = logits.reshape(-1, vocab)
    if lsm_prob == 0 or not training:
        loss = F.cross_entropy(logits, ys,
                               ignore_index=ignore_index, reduction='mean')
        ppl = np.exp(loss.item())
        if not normalize_length:
            # Convert the per-token mean into a per-utterance mean. The ratio
            # is a Python float so it is a true division on every PyTorch
            # version (integer-tensor `/` has changed meaning over time).
            n_tokens = (ys != ignore_index).sum().item()
            loss = loss * (n_tokens / bs)
    else:
        with torch.no_grad():
            # Smoothed target: `lsm_prob` spread over the other classes,
            # `1 - lsm_prob` on the gold class.
            target_dist = logits.new_zeros(logits.size())
            target_dist.fill_(lsm_prob / (vocab - 1))
            mask = (ys == ignore_index)
            # Padding positions get a valid dummy index; they are masked below.
            ys_masked = ys.masked_fill(mask, 0)
            target_dist.scatter_(1, ys_masked.unsqueeze(1), 1 - lsm_prob)

        log_probs = torch.log_softmax(logits, dim=-1)
        loss_sum = -torch.mul(target_dist, log_probs)
        n_tokens = len(ys) - mask.sum().item()
        denom = n_tokens if normalize_length else bs
        loss = loss_sum.masked_fill(mask.unsqueeze(1), 0).sum() / denom
        # Perplexity is always per token, whatever the loss normalisation.
        ppl = np.exp(loss.item()) if normalize_length else np.exp(loss.item() * bs / n_tokens)

    return loss, ppl
示例24
def predict(self, ys, state=None, mems=None, cache=None):
    """Predict function for ASR.

    Runs ``self.decode`` in incremental mode and converts the returned logits
    into log-probabilities over the last axis.

    Args:
        ys (LongTensor): `[B, L]`
        state:
            - RNNLM: dict
                hxs (FloatTensor): `[n_layers, B, n_units]`
                cxs (FloatTensor): `[n_layers, B, n_units]`
            - TransformerLM (LongTensor): `[B, L]`
            - TransformerXL (list): length `n_layers + 1`, each of which contains a tensor`[B, L, d_model]`
        mems (list):
        cache (list):
    Returns:
        lmout (FloatTensor): `[B, L, vocab]`, used for LM integration such as cold fusion
        state:
            - RNNLM: dict
                hxs (FloatTensor): `[n_layers, B, n_units]`
                cxs (FloatTensor): `[n_layers, B, n_units]`
            - TransformerLM (LongTensor): `[B, L]`
            - TransformerXL (list): length `n_layers + 1`, each of which contains a tensor`[B, L, d_model]`
        log_probs (FloatTensor): `[B, L, vocab]`

    """
    decoded = self.decode(ys, state, mems=mems, cache=cache, incremental=True)
    logits, lmout, new_state = decoded
    return lmout, new_state, torch.log_softmax(logits, dim=-1)
示例25
def recognize(self, h, recog_args):
    """Greedy search implementation for transformer-transducer.

    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options
    Returns:
        hyp (list of dicts): 1-best decoding results

    """
    hyp = {"score": 0.0, "yseq": [self.blank]}

    # Prime the decoder with the initial blank token.
    ys = to_device(self, torch.tensor(hyp["yseq"], dtype=torch.long)).unsqueeze(0)
    ys_mask = to_device(self, subsequent_mask(1).unsqueeze(0))
    y, c = self.forward_one_step(ys, ys_mask, None)

    for hi in h:
        ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
        logp, pred = torch.max(ytu, dim=0)
        if pred == self.blank:
            # Blank: keep the decoder state and move to the next frame.
            continue

        # Non-blank: extend the hypothesis and advance the decoder.
        hyp["yseq"].append(int(pred))
        hyp["score"] += float(logp)
        ys = to_device(self, torch.tensor(hyp["yseq"]).unsqueeze(0))
        ys_mask = to_device(
            self, subsequent_mask(len(hyp["yseq"])).unsqueeze(0)
        )
        y, c = self.forward_one_step(ys, ys_mask, c)

    return [hyp]
示例26
def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
    """Forward one step.

    :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
    :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
        dtype=torch.uint8 in PyTorch 1.2-
        dtype=torch.bool in PyTorch 1.2+ (include 1.2)
    :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
    :param List[torch.Tensor] cache:
        cached output list of (batch, max_time_out-1, size)
    :return y, cache: NN output value for the last position (``x[:, -1]``)
        and the new cache, one entry per layer of `self.decoders`.
    :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
    """
    hidden = self.embed(tgt)
    layer_caches = [None] * len(self.decoders) if cache is None else cache

    new_cache = []
    for layer_cache, decoder in zip(layer_caches, self.decoders):
        hidden, tgt_mask, memory, memory_mask = decoder(
            hidden, tgt_mask, memory, None, cache=layer_cache
        )
        # Each layer's output becomes its cache entry for the next step.
        new_cache.append(hidden)

    # Only the last position is turned into an output.
    last = hidden[:, -1]
    y = self.after_norm(last) if self.normalize_before else last
    if self.output_layer is not None:
        y = torch.log_softmax(self.output_layer(y), dim=-1)

    return y, new_cache
# beam search API (see ScorerInterface)
示例27
def forward_one_step(
    self,
    tgt: torch.Tensor,
    tgt_mask: torch.Tensor,
    memory: torch.Tensor,
    cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Forward one step.

    Args:
        tgt: input token ids, int64 (batch, maxlen_out)
        tgt_mask: input token mask, (batch, maxlen_out)
                  dtype=torch.uint8 in PyTorch 1.2-
                  dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        memory: encoded memory, float32 (batch, maxlen_in, feat)
        cache: cached output list of (batch, max_time_out-1, size);
            when None, ``self.init_state()`` supplies the initial cache
    Returns:
        y, cache: NN output value for the last position (``x[:, -1]``)
            and the new cache, one entry per layer of `self.decoders`.

    """
    hidden = self.embed(tgt)
    layer_caches = self.init_state() if cache is None else cache

    new_cache = []
    for layer_cache, decoder in zip(layer_caches, self.decoders):
        hidden, tgt_mask, memory, memory_mask = decoder(
            hidden, tgt_mask, memory, None, cache=layer_cache
        )
        # Each layer's output becomes its cache entry for the next step.
        new_cache.append(hidden)

    # Only the last position is turned into an output.
    last = hidden[:, -1]
    y = self.after_norm(last) if self.normalize_before else last
    if self.output_layer is not None:
        y = torch.log_softmax(self.output_layer(y), dim=-1)

    return y, new_cache
# beam search API (see ScorerInterface)
示例28
def forward(self, x):
    """Flatten a 4-D input and run it through the three layers.

    ReLU follows the first two layers; the output of the third is returned
    as log-probabilities over dim 1.
    """
    # Unpacking four sizes means anything but a 4-D input raises here.
    batch_size, _, _, _ = x.size()
    flat = x.view(batch_size, -1)

    hidden = torch.relu(self.layer_1(flat))
    hidden = torch.relu(self.layer_2(hidden))
    return torch.log_softmax(self.layer_3(hidden), dim=1)
示例29
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    """EOS must be blocked until ``min_length`` steps have been taken.

    Batch 0 always scores EOS highest; the other batches only score
    non-EOS tokens.
    """
    n_words = 100
    non_eos_idx = 47
    min_length = 5
    eos_idx = 2
    neg_inf = -float('inf')

    for batch_sz in [1, 3]:
        valid_score_dist = torch.log_softmax(torch.tensor([6., 5.]), dim=0)
        lengths = torch.randint(0, 30, (batch_sz,))
        samp = RandomSampling(
            0, 1, 2, batch_sz, torch.device("cpu"), min_length,
            False, set(), False, 30, 1., 1, lengths)
        all_attns = []
        for step in range(min_length + 4):
            word_probs = torch.full((batch_sz, n_words), neg_inf)
            # "best" prediction is eos - that should be blocked
            word_probs[0, eos_idx] = valid_score_dist[0]
            # include at least one prediction OTHER than EOS
            # that is greater than -1e20
            word_probs[0, non_eos_idx] = valid_score_dist[1]
            word_probs[1:, non_eos_idx + step] = 0

            attns = torch.randn(1, batch_sz, 53)
            all_attns.append(attns)
            samp.advance(word_probs, attns)

            if step > min_length:
                break
            if step == min_length:
                # now batch 0 has ended and no others have
                self.assertTrue(samp.is_finished[0, :].eq(1).all())
                self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
            else:
                # still below min_length: EOS is blocked, so batch 0 takes
                # the non-EOS score
                self.assertTrue(
                    samp.topk_scores[0].allclose(valid_score_dist[1]))
                self.assertTrue(samp.topk_scores[1:].eq(0).all())
示例30
def first_step(self, beam, expected_beam_scores, expected_len_pen):
    """Advance ``beam`` by one fixed step and verify scores, ids and back-pointers.

    ``expected_beam_scores`` are the scores expected before this step and
    ``expected_len_pen`` is the divisor used to check ``beam.topk_scores``.
    In this step beam 2 predicts EOS, so exactly one beam per batch must be
    marked finished.

    Returns the expected beam scores after the step.
    """
    # no EOS's yet
    assert beam.is_finished.sum() == 0
    scores_1 = torch.log_softmax(torch.tensor(
        [[0, 0, 0, .3, 0, .51, .2, 0],
         [0, 0, 1.5, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, .49, .48, 0, 0],
         [0, 0, 0, .2, .2, .2, .2, .2],
         [0, 0, 0, .2, .2, .2, .2, .2]]
    ), dim=1)
    scores_1 = scores_1.repeat(self.BATCH_SZ, 1)

    beam.advance(deepcopy(scores_1), self.random_attn())

    new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1)
    expected_beam_scores, unreduced_preds = new_scores\
        .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)\
        .topk(self.BEAM_SZ, -1)
    # Split the flat (beam * N_WORDS + word) index into beam and word.
    # `//` keeps the back-pointers integral: `/` on integer tensors is true
    # division in current PyTorch and would yield floats.
    expected_bptr_1 = unreduced_preds // self.N_WORDS
    # [5, 3, 2, 6, 0], so beam 2 predicts EOS!
    expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS

    self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
    self.assertTrue(beam.topk_scores.allclose(
        expected_beam_scores / expected_len_pen))
    self.assertTrue(beam.topk_ids.equal(expected_preds_1))
    self.assertTrue(beam.current_backptr.equal(expected_bptr_1))
    self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
    self.assertTrue(beam.is_finished[:, 2].all())  # beam 2 finished
    beam.update_finished()
    self.assertFalse(beam.top_beam_finished.any())
    self.assertFalse(beam.done)
    return expected_beam_scores