Python source examples: torch.tensordot()
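torch.tensordot(a, b, dims) contracts two tensors along the given dimensions: an integer dims contracts the last dims axes of a with the first dims axes of b, while a pair of index lists names the contracted axes explicitly. A minimal sketch (shapes chosen only for illustration):

import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 5, 6)

# Integer form: contract the last 2 axes of a with the first 2 axes of b -> (3, 6).
c = torch.tensordot(a, b, dims=2)

# List form: name the contracted axes explicitly; equivalent to the call above.
c2 = torch.tensordot(a, b, dims=([1, 2], [0, 1]))
assert c.shape == (3, 6)
assert torch.allclose(c, c2)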
Example 1
def idct_8x8(image):
    alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
    alpha = torch.FloatTensor(np.outer(alpha, alpha)).cuda()
    image = image * alpha

    tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
    for x, y, u, v in itertools.product(range(8), repeat=4):
        tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
            (2 * v + 1) * y * np.pi / 16)
    # result = 0.25 * torch.tensordot(image, torch.as_tensor(tensor, device="cuda"), dims=2) + 128
    result = 0.25 * tensordot_pytorch(image, torch.as_tensor(tensor, device="cuda"), dims=2) + 128
    result = result.view(image.size())
    return result

# -3. Block joining
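Note: tensordot_pytorch is not defined in this snippet. Judging from Example 29 below, it is presumably a compatibility wrapper that uses torch.tensordot when available and otherwise falls back to an einsum-based contraction. A hypothetical minimal sketch (an assumption, not the original helper):

def tensordot_pytorch(x, y, dims=2):
    # Assumed behavior: delegate to torch.tensordot on builds that have it;
    # older builds would need the einsum fallback shown in Example 29.
    return torch.tensordot(x, y, dims=dims)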
Example 2
def get_obs(Asymm, H, Sx, Sy, Sz, C, E):
    # A(phy,u,l,d,r), C(d,r), E(u,r,d)
    Da = Asymm.size()
    Td = torch.einsum('mefgh,nabcd->eafbgchdmn', (Asymm, Asymm)).contiguous().view(Da[1]**2, Da[2]**2, Da[3]**2, Da[4]**2, Da[0], Da[0])
    # print(torch.dist(Td, Td.permute(0,3,2,1,4,5)))  # test left-right reflection symmetry of Td

    CE = torch.tensordot(C, E, ([1], [0]))          # C(1d)E(dga)->CE(1ga)
    EL = torch.tensordot(E, CE, ([2], [0]))         # E(2e1)CE(1ga)->EL(2ega), use E(2e1) == E(1e2)
    EL = torch.tensordot(EL, Td, ([1, 2], [1, 0]))  # EL(2ega)T(gehbmn)->EL(2ahbmn)
    EL = torch.tensordot(EL, CE, ([0, 2], [0, 1]))  # EL(2ahbmn)CE(2hc)->EL(abmnc), use CE(2hc) == CE(1ga)
    Rho = torch.tensordot(EL, EL, ([0, 1, 4], [0, 1, 4])).permute(0, 2, 1, 3).contiguous().view(Da[0]**2, Da[0]**2)
    # print((Rho - Rho.t()).norm())
    Rho = 0.5 * (Rho + Rho.t())

    Tnorm = Rho.trace()
    Energy = torch.mm(Rho, H).trace() / Tnorm
    Mx = torch.mm(Rho, Sx).trace() / Tnorm
    My = torch.mm(Rho, Sy).trace() / Tnorm
    Mz = torch.mm(Rho, Sz).trace() / Tnorm
    # print("Tnorm = %g, Energy = %g " % (Tnorm.item(), Energy.item()))

    return Energy, Mx, My, Mz
Example 3
def forward(self, inputs):
    embeds_vec_list = inputs
    row = []
    col = []
    for r, c in itertools.combinations(embeds_vec_list, 2):
        row.append(r)
        col.append(c)
    p = torch.cat(row, dim=1)
    q = torch.cat(col, dim=1)
    inner_product = p * q

    bi_interaction = inner_product
    attention_temp = F.relu(torch.tensordot(
        bi_interaction, self.attention_W, dims=([-1], [0])) + self.attention_b)
    self.normalized_att_score = F.softmax(torch.tensordot(
        attention_temp, self.projection_h, dims=([-1], [0])), dim=1)
    attention_output = torch.sum(
        self.normalized_att_score * bi_interaction, dim=1)
    attention_output = self.dropout(attention_output)  # training

    afm_out = torch.tensordot(
        attention_output, self.projection_p, dims=([-1], [0]))
    return afm_out
Example 4
def forward(self, inputs):
    if len(inputs.shape) != 3:
        raise ValueError(
            "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))

    querys = torch.tensordot(inputs, self.W_Query,
                             dims=([-1], [0]))  # None F D*head_num
    keys = torch.tensordot(inputs, self.W_key, dims=([-1], [0]))
    values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))

    # head_num None F D
    querys = torch.stack(torch.split(
        querys, self.att_embedding_size, dim=2))
    keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
    values = torch.stack(torch.split(
        values, self.att_embedding_size, dim=2))

    inner_product = torch.einsum(
        'bnik,bnjk->bnij', querys, keys)  # head_num None F F
    self.normalized_att_scores = F.softmax(
        inner_product, dim=-1)  # head_num None F F
    result = torch.matmul(self.normalized_att_scores,
                          values)  # head_num None F D

    result = torch.cat(torch.split(result, 1), dim=-1)
    result = torch.squeeze(result, dim=0)  # None F D*head_num
    if self.use_res:
        result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
    result = F.relu(result)

    return result
Example 5
def forward(self, inputs):
    x_0 = inputs.unsqueeze(2)
    x_l = x_0
    for i in range(self.layer_num):
        xl_w = torch.tensordot(x_l, self.kernels[i], dims=([1], [0]))
        dot_ = torch.matmul(x_0, xl_w)
        x_l = dot_ + self.bias[i] + x_l
    x_l = torch.squeeze(x_l, dim=2)
    return x_l
Example 6
def tensordot(x, y, dims):
    """
    Wrapper around :func:`torch.tensordot` or :func:`np.tensordot`
    to operate on real-valued Funsors.

    Note this operates only on the ``output`` tensor. To perform sum-product
    contractions on named dimensions, instead use ``+`` and
    :class:`~funsor.terms.Reduce`.

    Arguments should satisfy::

        len(x.shape) >= dims
        len(y.shape) >= dims
        dims == 0 or x.shape[-dims:] == y.shape[:dims]

    :param Funsor x: A left hand argument.
    :param Funsor y: A right hand argument.
    :param int dims: The number of output dimensions along which to contract.
    :rtype: Funsor
    """
    assert dims >= 0
    assert len(x.shape) >= dims
    assert len(y.shape) >= dims
    assert dims == 0 or x.shape[-dims:] == y.shape[:dims]
    x_start, x_end = 0, len(x.output.shape)
    y_start = x_end - dims
    y_end = y_start + len(y.output.shape)
    symbols = 'abcdefghijklmnopqrstuvwxyz'
    equation = '{},{}->{}'.format(symbols[x_start:x_end],
                                  symbols[y_start:y_end],
                                  symbols[x_start:y_start] + symbols[x_end:y_end])
    return Einsum(equation, (x, y))
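To see what equation this builds: with dims=1, an x of output shape (3, 4) and a y of output shape (4, 5) give x_end=2, y_start=1, y_end=3, so the equation is 'ab,bc->ac', an ordinary matrix product. A quick check of that index arithmetic on plain tensors (not Funsors):

import torch

x, y = torch.randn(3, 4), torch.randn(4, 5)
assert torch.allclose(torch.einsum('ab,bc->ac', x, y),
                      torch.tensordot(x, y, dims=1))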
Example 7
def _numeric_tensordot(x, y, dim):
    if get_backend() == "torch":
        import torch
        return torch.tensordot(x, y, dim)
    else:
        return np.tensordot(x, y, axes=dim)
Example 8
def test_tensor_tensordot(x_shape, xy_shape, y_shape):
    x = randn(x_shape + xy_shape)
    y = randn(xy_shape + y_shape)
    dim = len(xy_shape)
    actual = tensordot(Tensor(x), Tensor(y), dim)
    expected = Tensor(_numeric_tensordot(x, y, dim))
    assert_close(actual, expected, atol=1e-5, rtol=None)
Example 9
def rgb_to_ycbcr(image):
    matrix = np.array(
        [[65.481, 128.553, 24.966],
         [-37.797, -74.203, 112.],
         [112., -93.786, -18.214]],
        dtype=np.float32).T / 255
    shift = torch.as_tensor([16., 128., 128.], device="cuda")
    # result = torch.tensordot(image, torch.as_tensor(matrix, device="cuda"), dims=1) + shift
    result = tensordot_pytorch(image, matrix, dims=1) + shift
    result = result.view(image.size())
    return result
Example 10
def rgb_to_ycbcr_jpeg(image):
    matrix = np.array(
        [[0.299, 0.587, 0.114],
         [-0.168736, -0.331264, 0.5],
         [0.5, -0.418688, -0.081312]],
        dtype=np.float32).T
    shift = torch.as_tensor([0., 128., 128.], device="cuda")
    # result = torch.tensordot(image, torch.as_tensor(matrix, device="cuda"), dims=1) + shift
    result = tensordot_pytorch(image, torch.as_tensor(matrix, device='cuda'), dims=1) + shift
    result = result.view(image.size())
    return result

# 2. Chroma subsampling
Example 11
def dct_8x8(image):
    image = image - 128
    tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
    for x, y, u, v in itertools.product(range(8), repeat=4):
        tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
            (2 * y + 1) * v * np.pi / 16)
    alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
    scale = torch.FloatTensor(np.outer(alpha, alpha) * 0.25).cuda()
    # result = scale * torch.tensordot(image, torch.as_tensor(tensor, device="cuda"), dims=2)
    result = scale * tensordot_pytorch(image, torch.as_tensor(tensor, device="cuda"), dims=2)
    result = result.view(image.size())
    return result
Example 12
def ycbcr_to_rgb(image):
    matrix = np.array(
        [[298.082, 0, 408.583],
         [298.082, -100.291, -208.120],
         [298.082, 516.412, 0]],
        dtype=np.float32).T / 256
    shift = torch.as_tensor([-222.921, 135.576, -276.836], device="cuda")
    # result = torch.tensordot(image, torch.tensor(matrix, device="cuda"), dims=1) + shift
    result = tensordot_pytorch(image, torch.tensor(matrix, device="cuda"), dims=1) + shift
    result = result.view(image.size())
    return result
Example 13
def test_tensordot():
    backend = pytorch_backend.PyTorchBackend()
    a = backend.convert_to_tensor(2 * np.ones((2, 3, 4)))
    b = backend.convert_to_tensor(np.ones((2, 3, 4)))
    actual = backend.tensordot(a, b, ((1, 2), (1, 2)))
    expected = np.array([[24.0, 24.0], [24.0, 24.0]])
    np.testing.assert_allclose(expected, actual)
Example 14
def test_eigsh_lanczos_0():
    # This test should just not crash.
    dtype = torch.float64
    backend = pytorch_backend.PyTorchBackend()
    D = 4
    init = backend.randn((2, 2, 2), dtype=dtype)
    tmp = backend.randn((8, 8), dtype=dtype)
    H = tmp + backend.transpose(backend.conj(tmp), (1, 0))
    H = H.reshape([2, 2, 2, 2, 2, 2])

    def mv(x, mat):
        return torch.tensordot(mat, x, ([0, 3, 5], [2, 0, 1])).permute([2, 0, 1])

    backend.eigsh_lanczos(mv, [H], init, num_krylov_vecs=D)
Example 15
def forward(self, inputs, embed=True):
    if embed:
        return torch.nn.functional.embedding(inputs, self.w)
    else:
        return torch.tensordot(inputs, self.w.t(), 1) + self.b
Example 16
def sample(self, p, z=None):
    """Input p to be shaped [T,B,A,P] or [B,A,P], A: number of actions, P:
    number of atoms. Optional input z is the domain of atom values, shaped
    [P]. Vector epsilon of length B will apply across the batch dimension."""
    q = torch.tensordot(p, self.z if z is None else z, dims=1)
    return super().sample(q)
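The tensordot over the trailing atom dimension converts each probability vector into an expected value. A small sketch with made-up sizes:

import torch

# p: [B, A, P] per-action atom probabilities; z: [P] atom values.
p = torch.softmax(torch.randn(2, 3, 51), dim=-1)
z = torch.linspace(-10.0, 10.0, 51)
q = torch.tensordot(p, z, dims=1)  # Q-values, shape [B, A]
assert q.shape == (2, 3)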
Example 17
def reweight(self, msa1hot):
    # Reweight
    seqlen = msa1hot.size(1)
    id_min = seqlen * self.msa_cutoff
    id_mtx = torch.tensordot(msa1hot, msa1hot, [[1, 2], [1, 2]])
    id_mask = id_mtx > id_min
    weights = 1.0 / id_mask.float().sum(-1)
    return weights
Example 18
def reweight(self, msa1hot, eps=1e-9):
    # Reweight
    seqlen = msa1hot.size(2)
    id_min = seqlen * self.msa_cutoff
    id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
    id_mask = id_mtx > id_min
    weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
    return weights
Example 19
def calculate_reward(
    self,
    slots: SlateSlots,
    rewards: Optional[SlateSlotValues] = None,
    slot_values: Optional[SlateSlotValues] = None,
    slot_weights: Optional[SlateSlotValues] = None,
) -> float:
    if slot_values is None:
        assert rewards is not None
        slot_values = self.slot_values(rewards)
    values = slot_values.values.to(device=self._device)
    if slot_weights is None:
        slot_weights = self.slot_weights(slots)
    weights = slot_weights.values.to(device=self._device)
    return torch.tensordot(values, weights, dims=([0], [0])).item()
Example 20
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
    log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
    if log_slot_expects is None:
        logger.warning("  Log slot distribution not available")
        return None
    tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
    if tgt_slot_expects is None:
        logger.warning("  Target slot distribution not available")
        return None
    slate_size = len(sample.context.slots)
    slot_weights = sample.slot_weights
    if slot_weights is None:
        slot_weights = SlateSlotValues(torch.ones(slate_size, dtype=torch.double))
    weights = slot_weights.values.to(device=self._device)
    if sample.slot_probabilities is not None:
        weights *= sample.slot_probabilities.values
    h = torch.zeros(slate_size, dtype=torch.double, device=self._device)
    p = torch.zeros(slate_size, dtype=torch.double, device=self._device)
    i = 0
    for slot, item in sample.log_slate:
        h[i] = tgt_slot_expects[slot][item]
        p[i] = log_slot_expects[slot][item]
        i += 1
    nu = torch.tensordot(h, weights, dims=([0], [0]))
    de = torch.tensordot(p, weights, dims=([0], [0]))
    if nu == de:
        weight = 1.0
    elif nu == 0:
        weight = 0.0
    elif de == 0:
        return None
    else:
        weight = self._weight_clamper(nu / de)
    return EstimatorSampleResult(
        sample.log_reward,
        sample.log_reward * weight,
        sample.ground_truth_reward,
        weight,
    )
# pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently.
Example 21
def soft_embedding_lookup(embedding, soft_ids):
    r"""Transforms soft ids (e.g., probability distribution over ids) into
    embeddings, by mixing the embedding vectors with the soft weights.

    Args:
        embedding: A Tensor of shape ``[num_classes] + embedding-dim``
            containing the embedding vectors. Embedding can have dimensionality
            > 1, i.e., :attr:`embedding` can be of shape
            ``[num_classes, emb_dim_1, emb_dim_2, ...]``
        soft_ids: A Tensor of weights (probabilities) used to mix the
            embedding vectors.

    Returns:
        A Tensor of shape ``shape(soft_ids)[:-1] + shape(embedding)[1:]``. For
        example, if ``shape(soft_ids) = [batch_size, max_time, vocab_size]``
        and ``shape(embedding) = [vocab_size, emb_dim]``, then the returned
        tensor has shape ``[batch_size, max_time, emb_dim]``.

    Example::

        softmax = torch.nn.Softmax()
        decoder_outputs, ... = decoder(...)
        soft_seq_emb = soft_embedding_lookup(
            embedding, softmax(decoder_outputs.logits))
    """
    return torch.tensordot(soft_ids, embedding, dims=([-1], [0]))
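A shape-only sketch of the behavior described in the docstring (made-up sizes):

import torch

batch_size, max_time, vocab_size, emb_dim = 2, 5, 7, 4
embedding = torch.randn(vocab_size, emb_dim)
soft_ids = torch.softmax(torch.randn(batch_size, max_time, vocab_size), dim=-1)
soft_emb = torch.tensordot(soft_ids, embedding, dims=([-1], [0]))
assert soft_emb.shape == (batch_size, max_time, emb_dim)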
Example 22
def forward(self):
    self.output = self.net.forward(self.input)
    x = self.output
    [n, l, classes] = x.size()
    x = x.view(n * l, classes)
    # print(x)
    self.loss_ce = F.cross_entropy(x, self.target)
    if self.opt.entropy_loss_coeff > 0:
        S = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        S = -1.0 * S.mean()
        self.loss_ce += self.opt.entropy_loss_coeff * S
    self.metric_accuracy = (torch.argmax(x, 1) == self.target).sum().float() / len(self.target)
    # TODO: maybe implement humaneness_reg.
    # Problem: past notes are not available in the input here, so it would
    # have to be computed from the output instead, e.g.:
    # step_size = self.opt.step_size
    # humaneness_delta = constants.HUMAN_DELTA
    # window_size = int(humaneness_delta/step_size)
    #
    # receptive_field = self.net.module.receptive_field
    # notes = (torch.argmax(input[:,-5:,receptive_field//2-(window_size):receptive_field//2],1)==4).float()
    # distance_factor = torch.tensor(np.exp(-2*np.arange(window_size,0,-1)/window_size)).float().cuda()
    # if self.opt.entropy_loss_coeff > 0:
    #     weights = torch.tensordot(notes,distance_factor,dims=1)
    #     humaneness_reg = F.cross_entropy(x,torch.zeros(weights.shape).long().cuda(), reduction='none')
    #     humaneness_reg = torch.dot(humaneness_reg, weights)
    #     self.loss_humaneness_reg = humaneness_reg
    #     self.loss_total = self.loss_ce + self.opt.humaneness_reg_coeff * self.loss_humaneness_reg
    # else:
    #     self.loss_humaneness_reg = 0
    #     self.loss_total = self.loss_ce
    self.loss_humaneness_reg = 0
    self.loss_total = self.loss_ce
Example 23
def forward(self):
    self.output = self.net.forward(self.input)
    x = self.output
    [n, channels, classes, l] = x.size()
    x = x.transpose(1, 3).contiguous()
    x = x.view(n * l * channels, classes)
    self.loss_ce = F.cross_entropy(x, self.target)
    if self.opt.entropy_loss_coeff > 0:
        S = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        S = -1.0 * S.mean()
        self.loss_ce += self.opt.entropy_loss_coeff * S
    self.metric_accuracy = (torch.argmax(x, 1) == self.target).sum().float() / len(self.target)

    step_size = self.opt.step_size
    humaneness_delta = constants.HUMAN_DELTA
    window_size = int(humaneness_delta / step_size)
    receptive_field = self.net.module.receptive_field
    # weights = torch.sum(torch.argmax(self.input[:,-5:,receptive_field//2-(window_size-1):receptive_field//2],1)==4,1).float()
    notes = (torch.argmax(self.input[:, -5:, receptive_field//2 - window_size:receptive_field//2], 1) == 4).float()
    distance_factor = torch.tensor(np.exp(-2 * np.arange(window_size, 0, -1) / window_size)).float().cuda()
    weights = torch.tensordot(notes, distance_factor, dims=1)
    # self.loss_humaneness_reg = F.relu(humaneness_reg-1).mean()
    # humaneness_reg = -F.cross_entropy(x,torch.ones(weights.shape).long().cuda(), reduction='none')
    humaneness_reg = F.cross_entropy(x, torch.zeros(weights.shape).long().cuda(), reduction='none')
    humaneness_reg = torch.dot(humaneness_reg, weights)
    self.loss_humaneness_reg = humaneness_reg
    self.loss_total = self.loss_ce + self.opt.humaneness_reg_coeff * self.loss_humaneness_reg
Example 24
def _get_torch_and_device():
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')

    return _TORCH_DEVICE
Example 25
def forward(self, beta, pose, trans=None, simplify=False):
    """This module takes betas and poses in a batched manner.

    A pose is 3 * K + 3 (= self.kintree_table.shape[1] * 3) parameters, where K is the number of joints.
    A beta is a vector of size self.shapedirs.shape[2] that parameterizes the body shape.
    Since this is batched, multiple betas and poses should be concatenated along the zeroth dimension.
    See http://files.is.tue.mpg.de/black/papers/SMPL2015.pdf for more info.
    """
    batch_size = beta.shape[0]  # Size of zeroth dimension.
    # The body shape is decomposed with principal component analysis from many subjects,
    # where self.v_template is the average value. shapedirs is a subset of the orthogonal
    # directions, and the betas are the coordinates of a subject projected onto them.
    # v_shaped is the "restored" subject.
    v_shaped = torch.tensordot(beta, self.shapedirs, dims=([1], [2])) + self.v_template
    # Turn the rotation vectors into rotation matrices.
    R_cube = self.rodrigues(pose.reshape(-1, 1, 3)).reshape(batch_size, -1, 3, 3)
    J = self.regress_joints(v_shaped)  # Joints in T-pose (for limb lengths).
    if not simplify:
        # Add pose blend shapes: how joint angles morph the surface.
        lrotmin = R_cube[:, 1:] - self.eye
        lrotmin = lrotmin.reshape(batch_size, -1)
        v_shaped += torch.tensordot(lrotmin, self.posedirs, dims=([1], [2]))
    # Now we have the un-posed body shape. Convert to homogeneous coordinates.
    rest_shape_h = torch.cat((v_shaped, v_shaped.new_ones(1).expand(*v_shaped.shape[:-1], 1)), 2)
    G = [self.rotate_translate(R_cube[:, 0], J[:, 0])]
    for i in range(1, self.kintree_table.shape[1]):
        G.append(
            torch.bmm(
                G[self.parent[i]],
                self.rotate_translate(R_cube[:, i], J[:, i] - J[:, self.parent[i]])))
    G = torch.stack(G, 1)
    Jtr = G[..., :4, 3].clone()
    G = G - self.pack(torch.matmul(G, torch.cat([J, J.new_zeros(1).expand(*J.shape[:2], 1)], dim=2).unsqueeze(-1)))
    # T = torch.tensordot(self.weights, G, dims=([1], [1]))
    # v = T.reshape(-1, 4, 4).bmm(rest_shape_h.reshape(-1, 4, 1)).reshape(batch_size, -1, 4)
    # The next two lines are a memory bottleneck.
    T = torch.tensordot(G, self.weights, dims=([1], [1])).permute(0, 3, 1, 2)
    v = torch.matmul(T, torch.reshape(rest_shape_h, (batch_size, -1, 4, 1))).reshape(batch_size, -1, 4)
    if trans is not None:
        trans = trans.unsqueeze(1)
        v[..., :3] += trans
        Jtr[..., :3] += trans
    return v, Jtr
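The first tensordot mixes the PCA shape directions with the betas. A shape-only sketch with made-up sizes (the SMPL mesh has 6890 vertices and typically 10 betas):

import torch

beta = torch.randn(2, 10)             # [batch, n_betas]
shapedirs = torch.randn(6890, 3, 10)  # [n_verts, 3, n_betas]
# Contract beta's dim 1 with shapedirs' dim 2 -> per-vertex offsets [batch, n_verts, 3].
v_offset = torch.tensordot(beta, shapedirs, dims=([1], [2]))
assert v_offset.shape == (2, 6890, 3)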
Example 26
def forward(self, heatmap, points, target_hm, target_points):
    if self.loss_type == 'l2_softargmax' or self.loss_type == 'l2_sm':
        mse_loss = (points - target_points) ** 2
        location_loss = mse_loss.sum(2).sum(1).mean()
    elif self.loss_type == 'l2_heatmap' or self.loss_type == 'l2_hm':
        mse_loss = (heatmap - target_hm) ** 2
        location_loss = mse_loss.sum(3).sum(2).sum(1).mean()
    elif self.loss_type == 'l1_softargmax' or self.loss_type == 'l1_sm':
        l1_loss = torch.abs(points - target_points)
        location_loss = l1_loss.sum(2).sum(1).mean()
    else:
        print("Did not recognize loss function selection!")
        sys.exit(1)

    if self.include_geo:
        # Loss on co-linearity of points along the sides of the cone
        v53 = F.normalize(points[:, 5] - points[:, 3], dim=1)
        v31 = F.normalize(points[:, 3] - points[:, 1], dim=1)
        vA = 1.0 - torch.tensordot(v31, v53, dims=([1], [1]))
        v10 = F.normalize(points[:, 1] - points[:, 0], dim=1)
        vB = 1.0 - torch.tensordot(v10, v31, dims=([1], [1]))
        v64 = F.normalize(points[:, 6] - points[:, 4], dim=1)
        v42 = F.normalize(points[:, 4] - points[:, 2], dim=1)
        vC = 1.0 - torch.tensordot(v64, v42, dims=([1], [1]))
        v20 = F.normalize(points[:, 2] - points[:, 0], dim=1)
        vD = 1.0 - torch.tensordot(v42, v20, dims=([1], [1]))
        # Loss on horizontals on cones (color boundaries)
        h21 = F.normalize(points[:, 2] - points[:, 1], dim=1)
        h43 = F.normalize(points[:, 4] - points[:, 3], dim=1)
        hA = 1.0 - torch.tensordot(h43, h21, dims=([1], [1]))
        h65 = F.normalize(points[:, 6] - points[:, 5], dim=1)
        hB = 1.0 - torch.tensordot(h65, h43, dims=([1], [1]))
        geo_loss = self.geo_loss_gamma_horz * (hA + hB).mean() / 2 + self.geo_loss_gamma_vert * (vA + vB + vC + vD).mean() / 4
    else:
        geo_loss = torch.tensor(0)
    # print('----------')
    # print('Geo Loss: ' + str(geo_loss.item()))
    # print('Location Loss: ' + str(location_loss.item()))
    return location_loss, geo_loss, location_loss + geo_loss
Example 27
def renormalize(*tensors):
    # T(up,left,down,right), u=up, l=left, d=down, r=right
    # C(d,r), EL(u,r,d), EU(l,d,r)
    C, E, T, chi = tensors
    dimT, dimE = T.shape[0], E.shape[0]
    D_new = min(dimE * dimT, chi)

    # step 1: construct the density matrix Rho
    Rho = torch.tensordot(C, E, ([1], [0]))          # C(ef)*EU(fga)=Rho(ega)
    Rho = torch.tensordot(Rho, E, ([0], [0]))        # Rho(ega)*EL(ehc)=Rho(gahc)
    Rho = torch.tensordot(Rho, T, ([0, 2], [0, 1]))  # Rho(gahc)*T(ghdb)=Rho(acdb)
    Rho = Rho.permute(0, 3, 1, 2).contiguous().view(dimE * dimT, dimE * dimT)  # Rho(acdb)->Rho(ab;cd)
    Rho = Rho + Rho.t()
    Rho = Rho / Rho.norm()

    # step 2: get the isometry P
    U, S, V = svd(Rho)
    truncation_error = S[D_new:].sum() / S.sum()
    P = U[:, :D_new]  # projection operator
    # Could also use symeig since Rho is symmetric:
    # S, U = symeig(Rho)
    # sorted, indices = torch.sort(S.abs(), descending=True)
    # truncation_error = sorted[D_new:].sum()/sorted.sum()
    # S = S[indices][:D_new]
    # P = U[:, indices][:, :D_new]  # projection operator

    # step 3: renormalize C and E
    C = (P.t() @ Rho @ P)  # C(D_new, D_new)
    # EL(u,r,d)
    P = P.view(dimE, dimT, D_new)
    E = torch.tensordot(E, P, ([0], [0]))        # EL(def)P(dga)=E(efga)
    E = torch.tensordot(E, T, ([0, 2], [1, 0]))  # E(efga)T(gehb)=E(fahb)
    E = torch.tensordot(E, P, ([0, 2], [0, 1]))  # E(fahb)P(fhc)=E(abc)

    # step 4: symmetrize C and E
    C = 0.5 * (C + C.t())
    E = 0.5 * (E + E.permute(2, 1, 0))

    return C / C.norm(), E, S.abs() / S.abs().max(), truncation_error
Example 28
def loss(self, samples):
    """
    Computes the Distributional Q-learning loss, based on projecting the
    discounted rewards + target Q-distribution into the current Q-domain,
    with cross-entropy loss.

    Returns loss and KL-divergence-errors for use in prioritization.
    """
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make a 2-D tensor of the contracted z_domain for each data point,
    # with zeros where the next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
    ret = samples.return_.unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Mostly 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remain after projection)
    # dim-2: next_z atoms (summed in projection)

    with torch.no_grad():
        target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    ps = self.agent(*samples.agent_inputs)  # [B,A,P]
    p = select_at_indexes(samples.action, ps)  # [B,P]
    p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

    if self.prioritized_replay:
        losses *= samples.is_weights

    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p *
                       (torch.log(target_p) - torch.log(p.detach())), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

    if not self.mid_batch_reset:
        valid = valid_from_done(samples.done)
        loss = valid_mean(losses, valid)
        KL_div *= valid
    else:
        loss = torch.mean(losses)

    return loss, KL_div
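A toy sketch of the projection coefficients described in the docstring, with a 3-atom domain (made-up numbers):

import torch

V_min, V_max, n_atoms = -1.0, 1.0, 3
delta_z = (V_max - V_min) / (n_atoms - 1)  # 1.0
z = torch.linspace(V_min, V_max, n_atoms)  # [-1, 0, 1]
next_z = torch.tensor([[0.5, 0.5, 0.5]])   # [B=1, P']
coeffs = torch.clamp(1 - (next_z.unsqueeze(1) - z.view(1, -1, 1)).abs() / delta_z, 0, 1)
# Each next_z atom at 0.5 splits half-and-half between base atoms 0.0 and 1.0:
# coeffs[0] == [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]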
Example 29
def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum.
    """
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)
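A quick consistency check of this fallback against torch.tensordot (assuming einsum here wraps torch.einsum and _torch_symbols_base is a string of index letters such as 'abcdefg...'):

import torch

x, y = torch.randn(2, 3, 4), torch.randn(3, 4, 5)
# axes=2 contracts the trailing (3, 4) of x with the leading (3, 4) of y.
assert torch.allclose(tensordot(x, y, axes=2), torch.tensordot(x, y, dims=2))
# Explicit-axes form: contract x's dim 1 with y's dim 0.
assert torch.allclose(tensordot(x, y, axes=([1], [0])),
                      torch.tensordot(x, y, dims=([1], [0])))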