def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
    """Compute an exclusive cumulative product along ``dim``.

    ``torch.cumprod`` is inclusive::

        cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^{n} x_i]

    The exclusive variant starts from the multiplicative identity and drops
    the final full product::

        exclusive_cumprod(x) = [1, x1, x1x2, ..., prod_{i=1}^{n-1} x_i]

    Args:
        tensor: Input tensor. It is passed through ``safe_cumprod``, which
            rejects inputs where ``tensor + eps`` has negative entries.
        dim: Dimension along which the product is accumulated.
        eps: Offset forwarded to ``safe_cumprod``.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    # Prepend a slice of ones along ``dim`` so that the inclusive product of
    # the padded tensor is the exclusive product of the original one.
    ones_size = list(tensor.size())
    ones_size[dim] = 1
    padded = torch.cat([torch.ones(ones_size).type_as(tensor), tensor], dim=dim)
    return_tensor = safe_cumprod(padded, dim=dim, eps=eps)

    # Drop the trailing slice (the full product). ``narrow`` works for any
    # dimension, so no per-dimension slicing chain is needed.
    return return_tensor.narrow(dim, 0, tensor.size(dim))
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
    """Cumulative product computed in log space to limit precision issues.

    Uses the identity::

        cumprod(x) = [x1, x1x2, x1x2x3, ...]
                   = [exp(log(x1)), exp(log(x1) + log(x2)), ...]
                   = exp(cumsum(log(x)))

    Args:
        tensor: Input tensor; ``tensor + eps`` must have no negative entries.
        dim: Dimension along which the product is accumulated.
        eps: Small offset added before taking the logarithm so that zeros do
            not produce ``-inf``.

    Returns:
        A tensor with the same shape as ``tensor``.

    Raises:
        RuntimeError: If any element of ``tensor + eps`` is negative.
    """
    if (tensor + eps < 0).any().item():
        raise RuntimeError(
            "Safe cumprod can only take non-negative tensors as input. "
            "Consider use torch.cumprod if you want to calculate negative values."
        )

    log_tensor = torch.log(tensor + eps)
    cumsum_log_tensor = torch.cumsum(log_tensor, dim)
    return torch.exp(cumsum_log_tensor)
def fake_cumprod(vb):
    """Cumulative product of a 2D tensor along its last dimension.

    args:
        vb: [hei x wid]
          -> NOTE: only cumprod along wid is supported
    returns:
        [1 x hei x wid] tensor whose entry [0, h, w] is prod(vb[h, :w + 1]).
        The leading singleton dimension is kept so existing callers that
        squeeze the result keep working.
    """
    # The earlier mask-and-prod emulation built a [wid x hei x wid] tensor and
    # was noted in this function as matching torch.cumprod exactly, so the
    # native op is used directly.
    return torch.cumprod(vb, dim=1).unsqueeze(0)
def allocate(self, usage, write_gate):
    """Compute allocation weights from the current memory usage.

    Args:
        usage: Usage tensor; indexed as [batch, slot] by this method
            (presumably [batch_size x mem_size] -- confirm against caller).
        write_gate: Not used by this method.

    Returns:
        A pair ``(allocation_weights, usage)`` where ``allocation_weights``
        has a singleton dimension inserted at position 1 and ``usage`` is the
        rescaled usage computed below.
    """
    # ensure values are not too small prior to cumprod.
    usage = δ + (1 - δ) * usage
    batch_size = usage.size(0)

    # free list: usage sorted ascending, φ holds the original slot indices
    sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False)

    # cumprod with exclusive=True: prepend a column of ones, take the
    # inclusive product, then drop the last column.
    # NOTE(review): the construction of the ones column was unreadable in the
    # source; rebuilt as a [batch_size x 1] tensor matching sorted_usage.
    v = sorted_usage.new_ones(batch_size, 1)
    cat_sorted_usage = T.cat((v, sorted_usage), 1)
    prod_sorted_usage = T.cumprod(cat_sorted_usage, 1)[:, :-1]

    # Both operands are already [batch x mem_size]; no squeeze is required.
    sorted_allocation_weights = (1 - sorted_usage) * prod_sorted_usage

    # construct the reverse sorting index: an ascending top-k over the full
    # permutation φ yields its inverse permutation.
    _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False)
    allocation_weights = sorted_allocation_weights.gather(1, φ_rev.long())

    return allocation_weights.unsqueeze(1), usage
def forward(self, c):
    """Evaluate the objective for weights ``c`` and count the evaluation.

    Increments ``self.eval``, then computes a "leading ones"-like value: the
    negated sum of the cumulative product of ``c * mean(self.sigmod(self.x), 1)``
    along dimension 0, plus a small L2 penalty on ``self.x``.

    Args:
        c: Tensor multiplied elementwise with the per-row mean.

    Returns:
        The objective value, returned twice as a pair ``(fx, fx)``.
    """
    self.eval += 1
    row_mean = torch.mean(self.sigmod(self.x), 1)
    # fx = - torch.sum(c * xx, 0)  # onemax like
    fx = -torch.cumprod(c * row_mean, 0).sum()  # leading ones like
    # Tiny quadratic regulariser on the raw parameters.
    fx += 1e-8 * torch.sum(self.x ** 2)
    return fx, fx
def dice_objective(self):
    """Build the (negated) DiCE objective from the stored trajectory.

    Stacks ``self.self_logprobs``, ``self.other_logprobs``, ``self.values``
    and ``self.rewards`` along dimension 1, discounts rewards and values by
    ``hp.gamma ** t``, and combines them through ``magic_box``.

    Returns:
        The negated objective, so that minimising it maximises the objective.
    """
    self_logprobs = torch.stack(self.self_logprobs, dim=1)
    other_logprobs = torch.stack(self.other_logprobs, dim=1)
    values = torch.stack(self.values, dim=1)
    rewards = torch.stack(self.rewards, dim=1)

    # apply discount: cumprod of gamma gives [g, g^2, ...]; dividing by gamma
    # shifts it to [1, g, g^2, ...]. ones_like keeps the discount tensor on
    # the same device/dtype as the rewards.
    cum_discount = torch.cumprod(hp.gamma * torch.ones_like(rewards), dim=1) / hp.gamma
    discounted_rewards = rewards * cum_discount
    discounted_values = values * cum_discount

    # logprob of each stochastic node:
    stochastic_nodes = self_logprobs + other_logprobs
    # stochastic nodes involved in rewards dependencies:
    dependencies = torch.cumsum(stochastic_nodes, dim=1)

    # dice objective:
    dice_objective = torch.mean(torch.sum(magic_box(dependencies) * discounted_rewards, dim=1))

    if hp.use_baseline:
        # variance reduction:
        baseline_term = torch.mean(
            torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1))
        dice_objective = dice_objective + baseline_term

    return -dice_objective  # want to minimize -objective
def dice_objective(self):
    """Build the (negated) DiCE objective from the stored trajectory.

    Stacks ``self.self_logprobs``, ``self.other_logprobs``, ``self.values``
    and ``self.rewards`` along dimension 1, discounts rewards and values by
    ``hp.gamma ** t``, and combines them through ``magic_box``.

    Returns:
        The negated objective, so that minimising it maximises the objective.
    """
    self_logprobs = torch.stack(self.self_logprobs, dim=1)
    other_logprobs = torch.stack(self.other_logprobs, dim=1)
    values = torch.stack(self.values, dim=1)
    rewards = torch.stack(self.rewards, dim=1)

    # apply discount: cumprod of gamma gives [g, g^2, ...]; dividing by gamma
    # shifts it to [1, g, g^2, ...]. ones_like keeps the discount tensor on
    # the same device/dtype as the rewards.
    cum_discount = torch.cumprod(hp.gamma * torch.ones_like(rewards), dim=1) / hp.gamma
    discounted_rewards = rewards * cum_discount
    discounted_values = values * cum_discount

    # logprob of each stochastic node:
    stochastic_nodes = self_logprobs + other_logprobs
    # stochastic nodes involved in rewards dependencies:
    dependencies = torch.cumsum(stochastic_nodes, dim=1)

    # dice objective:
    dice_objective = torch.mean(torch.sum(magic_box(dependencies) * discounted_rewards, dim=1))

    if hp.use_baseline:
        # variance reduction:
        baseline_term = torch.mean(
            torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1))
        dice_objective = dice_objective + baseline_term

    return -dice_objective  # want to minimize -objective
def _allocation(self, usage_vb, epsilon=1e-6):
    r"""
    computes allocation by sorting usage, a = a_t[\phi_t[j]]
    variables needed:
        usage_vb: [batch_size x mem_hei]
               -> indicating current memory usage, this is equal to u_t in
                  the paper when we only have one write head, but for
                  multiple write heads, one should update the usage while
                  iterating through the write heads to take into account the
                  allocation returned by this function
    returns:
        alloc_weight_vb: same shape as usage_vb, the allocation weights
                  restored to the original (unsorted) slot order
    """
    # ensure values are not too small prior to cumprod
    usage_vb = epsilon + (1 - epsilon) * usage_vb
    # NOTE: we sort usage in ascending order
    sorted_usage_vb, indices_vb = torch.topk(usage_vb, k=self.mem_hei, dim=1, largest=False)
    # to imitate tf.cumprod(exclusive=True): prepend a column of ones and drop
    # the last column before taking the inclusive product.
    # NOTE(review): the construction of the ones column was unreadable in the
    # source; the batch size is taken from the input here -- confirm.
    ones_vb = torch.ones(sorted_usage_vb.size(0), 1).type(self.dtype)
    cat_sorted_usage_vb = torch.cat((ones_vb, sorted_usage_vb), 1)[:, :-1]
    prod_sorted_usage_vb = fake_cumprod(cat_sorted_usage_vb)
    # fake_cumprod returns a leading singleton dimension; remove only that one.
    # A bare squeeze() would also collapse a batch or memory dimension of
    # size 1 and break the elementwise product below.
    alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb.squeeze(0)  # equ. (1)
    # An ascending top-k over the full permutation yields its inverse, which
    # maps the sorted weights back to their original slots.
    _, indices_vb = torch.topk(indices_vb, k=self.mem_hei, dim=1, largest=False)
    alloc_weight_vb = alloc_weight_vb.gather(1, indices_vb)
    return alloc_weight_vb
def exclusive_cumprod(x):
    """Exclusive cumulative product [a, b, c] => [1, a, a * b].

    The product runs along the last dimension. A leading one is prepended and
    the last element is dropped before the inclusive product is taken, so the
    output always has the same shape as the input.

    Args:
        x (FloatTensor): `[B, H, qlen, klen]` (any tensor with at least one
            dimension is accepted)
    Returns:
        x (FloatTensor): `[B, H, qlen, klen]`

    """
    # ones_like on a width-1 slice matches x's leading dims, dtype and device,
    # and stays empty when the last dimension is empty.
    shifted = torch.cat([torch.ones_like(x[..., :1]), x[..., :-1]], dim=-1)
    return torch.cumprod(shifted, dim=-1)
def cumprod(a, axis, dtype=None, out=None):
    """
    Return the cumulative product of elements along a given axis.

    Parameters
    ----------
    a : DNDarray
        Input array.
    axis : int
        Axis along which the cumulative product is computed.
    dtype : dtype, optional
        Type of the returned array, as well as of the accumulator in which
        the elements are multiplied. If *dtype* is not specified, it
        defaults to the dtype of `a`, unless `a` has an integer dtype with
        a precision less than that of the default platform integer. In
        that case, the default platform integer is used instead.
    out : DNDarray, optional
        Alternative output array in which to place the result. It must
        have the same shape and buffer length as the expected output
        but the type of the resulting values will be cast if necessary.

    Returns
    -------
    cumprod : DNDarray
        A new array holding the result is returned unless `out` is
        specified, in which case a reference to out is returned.

    Examples
    --------
    >>> a = ht.full((3,3), 2)
    >>> ht.cumprod(a, 0)
    tensor([[2., 2., 2.],
            [4., 4., 4.],
            [8., 8., 8.]])
    """
    # Delegates to the shared cumulative-operation helper, passing the local
    # op (torch.cumprod), the MPI reduction (MPI.PROD), the combining op
    # (torch.mul) and 1 as the neutral element.
    return operations.__cum_op(a, torch.cumprod, MPI.PROD, torch.mul, 1, axis, dtype, out)
# Alias support
def tor_err_at_ks(sys_sorted_labels, ks=None, multi_level_rele=True, max_rele_level=None):
    """Expected reciprocal rank (ERR) at each cutoff in ``ks``.

    :param sys_sorted_labels: the standard labels sorted in descending order according to predicted relevance scores
    :param ks: cutoff positions; cutoffs larger than the number of labels are
        reported as 0 in the trailing entries of the result
    :param multi_level_rele: only ``True`` is implemented
    :param max_rele_level: maximum relevance grade used to normalise the gain
    :return: 1D tensor with one value per entry of ``ks``
    """
    valid_max = sys_sorted_labels.size(0)
    # Keep only the cutoffs that fit into the available ranking.
    used_ks = [k for k in ks if k <= valid_max] if valid_max < max(ks) else ks

    max_cutoff = max(used_ks)
    inds = torch.from_numpy(np.asarray(used_ks) - 1)

    if multi_level_rele:
        device = sys_sorted_labels.device
        positions = torch.arange(max_cutoff, device=device) + 1.0
        expt_ranks = 1.0 / positions  # expected stop positions

        # Probability that the user is satisfied at each position.
        max_gain = 2.0 ** float(max_rele_level)
        satis_pros = (torch.pow(2.0, sys_sorted_labels[0:max_cutoff]) - 1.0) / max_gain
        non_satis_pros = 1.0 - satis_pros
        cum_non_satis_pros = torch.cumprod(non_satis_pros, dim=0)

        # Exclusive cumulative product: probability of reaching each position
        # without having been satisfied earlier (1 for the first position).
        cascad_non_satis_pros = torch.cat(
            [cum_non_satis_pros.new_ones(1), cum_non_satis_pros[:-1]])
        expt_satis_ranks = expt_ranks * satis_pros * cascad_non_satis_pros  # w.r.t. all rank positions
        err_at_ranks = torch.cumsum(expt_satis_ranks, dim=0)

        err_at_ks = err_at_ranks[inds]
        if valid_max < max(ks):
            # Pad the cutoffs that exceeded the ranking length with zeros.
            padded_err_at_ks = torch.zeros(len(ks), device=device)
            padded_err_at_ks[0:len(used_ks)] = err_at_ks
            return padded_err_at_ks
        else:
            return err_at_ks
    else:
        raise NotImplementedError
def torch_ideal_err(sorted_labels, k=10, point=True, gpu=False):
    """Expected reciprocal rank (ERR) of a ranking given by ``sorted_labels``.

    The gain is normalised by ``2 ** max(sorted_labels)``.

    :param sorted_labels: 1D tensor of labels in ranking order, with at least
        ``k`` entries
    :param k: cutoff position
    :param point: if True return the single ERR value at cutoff ``k``,
        otherwise return the ERR at every cutoff ``1..k``
    :param gpu: kept for backward compatibility; intermediate tensors are
        created on the device of ``sorted_labels``
    :return: 0-dim tensor if ``point`` else a 1D tensor of length ``k``
    """
    assert sorted_labels.size(0) >= k

    max_label = torch.max(sorted_labels)
    labels = sorted_labels[0:k]

    # Probability that the user is satisfied at each position.
    satis_pros = (torch.pow(2.0, labels) - 1.0) / torch.pow(2.0, max_label)
    unsatis_pros = torch.ones_like(labels) - satis_pros
    cum_unsatis_pros = torch.cumprod(unsatis_pros, dim=0)

    ranks = torch.arange(k, device=sorted_labels.device) + 1.0
    expt_ranks = 1.0 / ranks

    # Exclusive cumulative product: probability of reaching each position
    # without having been satisfied earlier (1 for the first position).
    cascad_unsatis_pros = torch.cat(
        [torch.ones_like(cum_unsatis_pros[:1]), cum_unsatis_pros[:-1]])
    expt_satis_ranks = expt_ranks * satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

    if point:  # a specific position
        ideal_err = torch.sum(expt_satis_ranks, dim=0)
        return ideal_err
    else:
        ideal_err_at_ks = torch.cumsum(expt_satis_ranks, dim=0)
        return ideal_err_at_ks
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
    """Batched expected reciprocal rank (ERR) over rows of sorted labels.

    Each row is normalised by ``2 ** max(row)``.

    :param batch_sorted_labels: 2D tensor [batch, ranking_size] of labels in
        ranking order, with at least ``k`` columns
    :param k: cutoff position
    :param gpu: kept for backward compatibility; intermediate tensors are
        created on the device of ``batch_sorted_labels``
    :param point: if True return one ERR value per row at cutoff ``k``,
        otherwise return the ERR at every cutoff ``1..k`` per row
    :return: [batch] tensor if ``point`` else a [batch, k] tensor
    """
    # Slicing the first k columns is valid when exactly k are available.
    assert batch_sorted_labels.size(1) >= k

    # torch.max with ``dim`` returns (values, indices); only the values are
    # needed, kept as a column so they broadcast against [batch, k].
    batch_max, _ = torch.max(batch_sorted_labels, dim=1, keepdim=True)
    batch_labels = batch_sorted_labels[:, 0:k]

    batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)
    batch_unsatis_pros = torch.ones_like(batch_labels) - batch_satis_pros
    batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

    positions = torch.arange(k, device=batch_sorted_labels.device) + 1.0
    batch_expt_ranks = (1.0 / positions).view(1, -1)  # broadcast over the batch

    # Exclusive cumulative product per row: probability of reaching each
    # position without having been satisfied earlier (1 for the first one).
    cascad_unsatis_pros = torch.cat(
        [torch.ones_like(batch_cum_unsatis_pros[:, :1]), batch_cum_unsatis_pros[:, :-1]],
        dim=1)
    expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

    if point:
        batch_errs = torch.sum(expt_satis_ranks, dim=1)
        return batch_errs
    else:
        batch_err_at_ks = torch.cumsum(expt_satis_ranks, dim=1)
        return batch_err_at_ks
def torch_nerr_at_ks(sys_sorted_labels, ideal_sorted_labels, ks=None, multi_level_rele=True):
    """Normalised expected reciprocal rank (nERR) at each cutoff in ``ks``.

    The ERR of the system ranking is divided by the ERR obtained from
    ``ideal_sorted_labels`` via ``torch_ideal_err``. The gain is normalised
    by ``2 ** max(sys_sorted_labels)``.

    :param sys_sorted_labels: the standard labels sorted in descending order according to predicted relevance scores
    :param ideal_sorted_labels: labels in ideal ranking order
    :param ks: cutoff positions; cutoffs larger than the number of labels are
        reported as 0 in the trailing entries of the result
    :param multi_level_rele: only ``True`` is implemented
    :return: 1D tensor with one value per entry of ``ks``
    """
    valid_max = sys_sorted_labels.size(0)
    # Keep only the cutoffs that fit into the available ranking.
    used_ks = [k for k in ks if k <= valid_max] if valid_max < max(ks) else ks

    max_cutoff = max(used_ks)
    inds = torch.from_numpy(np.asarray(used_ks) - 1)

    if multi_level_rele:
        positions = torch.arange(max_cutoff) + 1.0
        expt_ranks = 1.0 / positions  # expected stop positions

        tor_max_rele = torch.max(sys_sorted_labels)
        satis_pros = (torch.pow(2.0, sys_sorted_labels[0:max_cutoff]) - 1.0) / torch.pow(2.0, tor_max_rele)
        non_satis_pros = torch.ones(max_cutoff) - satis_pros
        cum_non_satis_pros = torch.cumprod(non_satis_pros, dim=0)

        # Exclusive cumulative product: probability of reaching each position
        # without having been satisfied earlier (1 for the first position).
        cascad_non_satis_pros = torch.cat(
            [cum_non_satis_pros.new_ones(1), cum_non_satis_pros[:-1]])
        expt_satis_ranks = expt_ranks * satis_pros * cascad_non_satis_pros  # w.r.t. all rank positions
        err_at_ks = torch.cumsum(expt_satis_ranks, dim=0)

        ideal_err_at_ks = torch_ideal_err(ideal_sorted_labels, k=max_cutoff, point=False)
        # NOTE(review): an ideal ERR of 0 (no relevant labels) yields nan/inf
        # here -- confirm callers filter such queries.
        tmp_nerr_at_ks = err_at_ks / ideal_err_at_ks

        nerr_at_ks = tmp_nerr_at_ks[inds]
        if valid_max < max(ks):
            # Pad the cutoffs that exceeded the ranking length with zeros.
            padded_nerr_at_ks = torch.zeros(len(ks))
            padded_nerr_at_ks[0:len(used_ks)] = nerr_at_ks
            return padded_nerr_at_ks
        else:
            return nerr_at_ks
    else:
        raise NotImplementedError
def __call__(self, x):
    r"""
    Evaluates E_S[softmax_{customers} min_{i \in S} dist(customer, i)] where
    the expectation is over the set of facility locations S. Every
    location is included in S independently with probability x_i.

    Reads ``self.order``, ``self.m``, ``self.dmax``, ``self.hardmax`` and
    ``self.temp``. Returns ``vals.max()`` when ``self.hardmax`` is set,
    otherwise the softmax-weighted combination of ``vals``.
    """
    x_sort = x[self.order]
    # 1 - prod_{j <= i}(1 - x_j): probability that at least one of the first
    # i + 1 entries along dim 1 is included.
    probs = 1 - torch.cumprod(1 - x_sort, dim=1)
    vals = self.dmax + (self.m * probs).sum(dim=1)
    if self.hardmax:
        return vals.max()
    weights = torch.softmax(self.temp * vals, dim=0)
    # NOTE(review): the reduction call was unreadable in the source; restored
    # as a matrix product of vals and weights (a dot product for 1D vals) --
    # confirm against the upstream implementation.
    return torch.matmul(vals, weights)
def forward(self, emb, parser_state):
    """Compute memory gates for the current and the next step.

    Args:
        emb: Embeddings of the current chunk; dimension 0 is time.
        parser_state: Pair ``(emb_last, cum_gate)`` carried over from the
            previous call.

    Returns:
        A triple ``((memory_gate, memory_gate_next), gate, new_state)`` where
        both memory gates are tuples with one tensor per time step and
        ``new_state`` is the ``(emb_last, cum_gate)`` pair for the next call,
        truncated to the last ``nlookback`` / ``nslots`` entries.
    """
    emb_last, cum_gate = parser_state
    ntimestep = emb.size(0)

    # Prepend the embeddings kept from the previous call.
    emb_last = torch.cat([emb_last, emb], dim=0)
    emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

    gates = self.gate(emb)  # bsz, 2, ntimestep
    gate = gates[:, 0, :]
    gate_next = gates[:, 1, :]
    cum_gate = torch.cat([cum_gate, gate], dim=1)
    # For every time step, the gates of the nslots preceding positions.
    gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                           dim=2)  # bsz, ntimestep, nslots

    def _memory_gate(current_gate):
        # Compare the current gate with each preceding gate, squash to [0, 1]
        # (hard or soft), then take the running product over the slots.
        scaled = (current_gate[:, :, None] - gate_hat) / self.resolution
        if self.hard:
            soft_cmp = (F.hardtanh(scaled * 2 + 1) + 1) / 2
        else:
            soft_cmp = torch.sigmoid(scaled * 10 + 5)  # bsz, ntimestep, nslots
        cumulative = torch.cumprod(soft_cmp, dim=2)  # bsz, ntimestep, nlookback+1
        return torch.unbind(cumulative, dim=1)

    memory_gate = _memory_gate(gate)
    memory_gate_next = _memory_gate(gate_next)

    return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])
def forward(self, emb, parser_state):
    """Compute memory gates for the current and the next step.

    Args:
        emb: Embeddings of the current chunk; dimension 0 is time.
        parser_state: Pair ``(emb_last, cum_gate)`` carried over from the
            previous call.

    Returns:
        A triple ``((memory_gate, memory_gate_next), gate, new_state)`` where
        both memory gates are tuples with one tensor per time step and
        ``new_state`` is the ``(emb_last, cum_gate)`` pair for the next call,
        truncated to the last ``nlookback`` / ``nslots`` entries.
    """
    emb_last, cum_gate = parser_state
    ntimestep = emb.size(0)

    # Prepend the embeddings kept from the previous call.
    emb_last = torch.cat([emb_last, emb], dim=0)
    emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

    gates = self.gate(emb)  # bsz, 2, ntimestep
    gate = gates[:, 0, :]
    gate_next = gates[:, 1, :]
    cum_gate = torch.cat([cum_gate, gate], dim=1)
    # For every time step, the gates of the nslots preceding positions.
    gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                           dim=2)  # bsz, ntimestep, nslots

    def _memory_gate(current_gate):
        # Compare the current gate with each preceding gate, squash to [0, 1]
        # (hard or soft), then take the running product over the slots.
        scaled = (current_gate[:, :, None] - gate_hat) / self.resolution
        if self.hard:
            soft_cmp = (F.hardtanh(scaled * 2 + 1) + 1) / 2
        else:
            soft_cmp = torch.sigmoid(scaled * 10 + 5)  # bsz, ntimestep, nslots
        cumulative = torch.cumprod(soft_cmp, dim=2)  # bsz, ntimestep, nlookback+1
        return torch.unbind(cumulative, dim=1)

    memory_gate = _memory_gate(gate)
    memory_gate_next = _memory_gate(gate_next)

    return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])
def _discount_reward_tensor_1d(reward: torch.Tensor,
                               sequence_length: Optional[torch.LongTensor],
                               discount: float = 1.) -> torch.Tensor:
    r"""Computes discounted reward.

    Args:
        reward: 1D Tensor with shape `[batch_size]`.
        sequence_length: A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will be masked.
        discount (float): A scalar. The discount factor.

    Returns:
        A 2D Tensor of the discounted reward.

    Raises:
        ValueError: If ``sequence_length`` is `None`.
    """
    if sequence_length is None:
        raise ValueError('sequence_length must not be `None` for 1D reward.')
    if not isinstance(sequence_length, torch.Tensor):
        sequence_length = torch.tensor(
            sequence_length, dtype=torch.int64, device=reward.device)

    batch_size = reward.shape[0]
    max_seq_length = torch.max(sequence_length)
    dtype: torch.dtype = reward.dtype

    if discount == 1.:
        # No discounting: repeat the reward across all time steps.
        disc_reward = reward.unsqueeze(-1).expand(batch_size, max_seq_length)
    else:
        mask = sequence_mask(sequence_length, dtype=dtype)
        # Shift the mask left by one step and pad with a zero column.
        mask = torch.cat((mask[:, 1:], torch.zeros_like(mask[:, -1:])), dim=1)
        # Make each row = [discount, ..., discount, 1, ..., 1]
        dmat = mask * discount + (1 - mask)
        # Reversed cumulative product: flip, accumulate, flip back.
        dmat = torch.flip(dmat, (1,))
        dmat = torch.cumprod(dmat, dim=1)
        dmat = torch.flip(dmat, (1,))
        disc_reward = dmat * reward.unsqueeze(-1)

    disc_reward = mask_sequences(disc_reward, sequence_length, dtype=dtype)

    return disc_reward
def compute_advantages(discount, gae_lambda, max_path_length, baselines,
                       rewards):
    """Calculate advantages.

    Advantages are a discounted cumulative sum.

    Calculate advantages using a baseline according to Generalized Advantage
    Estimation (GAE)

    The discounted cumulative sum can be computed using conv2d with filter.
    filter:
        [1, (discount * gae_lambda), (discount * gae_lambda) ^ 2, ...]
        where the length is same with max_path_length.

    baselines and rewards also have the same shape.
        baselines:
        [ [b_11, b_12, b_13, ... b_1n],
          [b_21, b_22, b_23, ... b_2n],
          ...
          [b_m1, b_m2, b_m3, ... b_mn] ]
        rewards:
        [ [r_11, r_12, r_13, ... r_1n],
          [r_21, r_22, r_23, ... r_2n],
          ...
          [r_m1, r_m2, r_m3, ... r_mn] ]

    Args:
        discount (float): RL discount factor (i.e. gamma).
        gae_lambda (float): Lambda, as used for Generalized Advantage
            Estimation (GAE).
        max_path_length (int): Maximum length of a single rollout.
        baselines (torch.Tensor): A 2D vector of value function estimates with
            shape (N, T), where N is the batch dimension (number of episodes)
            and T is the maximum path length experienced by the agent. If an
            episode terminates in fewer than T time steps, the remaining
            elements in that episode should be set to 0.
        rewards (torch.Tensor): A 2D vector of per-step rewards with shape
            (N, T), where N is the batch dimension (number of episodes) and T
            is the maximum path length experienced by the agent. If an episode
            terminates in fewer than T time steps, the remaining elements in
            that episode should be set to 0.

    Returns:
        torch.Tensor: A 2D vector of calculated advantage values with shape
            (N, T), where N is the batch dimension (number of episodes) and T
            is the maximum path length experienced by the agent. If an episode
            terminates in fewer than T time steps, the remaining values in that
            episode should be set to 0.

    """
    # TD residuals: r_t + discount * b_{t+1} - b_t, with b_{T+1} taken as 0.
    deltas = (rewards + discount * F.pad(baselines, (0, 1))[:, 1:] - baselines)

    # Filter [1, c, c^2, ...] with c = discount * gae_lambda, built in the
    # dtype/device of the residuals so conv2d does not see mismatched inputs.
    adv_filter = torch.full((1, 1, 1, max_path_length - 1),
                            discount * gae_lambda,
                            dtype=deltas.dtype,
                            device=deltas.device)
    adv_filter = torch.cumprod(F.pad(adv_filter, (1, 0), value=1), dim=-1)

    # Right-pad with zeros so every time step sees a full-length window.
    deltas = F.pad(deltas, (0, max_path_length - 1)).unsqueeze(0).unsqueeze(0)

    advantages = F.conv2d(deltas, adv_filter, stride=1).reshape(rewards.shape)
    return advantages