Python源码示例:torch.repeat_interleave()
示例1
def hard_k_hot(logits, k, temperature=0.1):
    r"""Returns a hard k-hot sample given a categorical
    distribution defined by a tensor of unnormalized
    log-likelihoods.

    This is useful for example to sample a set of pixels in an
    image to move from a grid-structured data representation to a
    set- or graph-structured representation within a network.

    The hard sample marks the ``k`` largest entries of ``logits`` along the
    last dimension. It is returned through
    ``replace_gradient(hard, soft)`` together with the relaxed sample from
    ``soft_k_hot``.

    Args:
        logits (torch.Tensor): unnormalized log-likelihood tensor. Items
            are selected along the last dimension; any number of leading
            (batch) dimensions is accepted.
        k (int): number of items to sample without replacement.
        temperature (float): temperature of the soft distribution.

    Returns:
        Hard k-hot vector from the relaxed k-hot distribution
        defined by logits and temperature.
    """
    soft = soft_k_hot(logits, k, temperature=temperature)
    _, top_k = torch.topk(logits, k)
    # scatter_ writes a 1 at each top-k position along the last dimension.
    # Unlike an explicit arange-based row index this needs no device
    # bookkeeping and is not limited to 2-D inputs.
    hard = torch.zeros_like(soft).scatter_(-1, top_k, 1.0)
    return replace_gradient(hard, soft)
示例2
def pairwise_no_pad(op, data, indices):
    """Apply ``op`` to every pair of rows (i, j), i < j, that share a group id.

    No padding is used: the left and right operands of every pair are
    gathered into two flat tensors with one row per pair.

    NOTE(review): the offset arithmetic assumes ``indices`` is sorted, so that
    each group is a contiguous run of rows in ``data`` — confirm with callers.

    Args:
        op: binary callable, called once as ``op(expanded, data[access])``
            where both arguments have one row per pair.
        data: tensor whose first dimension is indexed by ``indices``.
        indices: 1-D tensor with the group id of every row of ``data``.

    Returns:
        Tuple ``(result, pair_group_ids)``: the output of ``op`` and the
        group id of each pair (``indices`` repeated once per pair that row i
        starts).
    """
    unique, counts = indices.unique(return_counts=True)
    # For every row: number of later rows inside its own group, i.e.
    # (end of its group) - (its global position) - 1.
    expansion = torch.cumsum(counts, dim=0)
    expansion = torch.repeat_interleave(expansion, counts)
    offset = torch.arange(0, counts.sum(), device=data.device)
    expansion = expansion - offset - 1
    # Left operand: row i repeated once for each later row in its group.
    expanded = torch.repeat_interleave(data, expansion.to(data.device), dim=0)
    # Global start row of each group, broadcast to one entry per pair.
    expansion_offset = counts.roll(1)
    expansion_offset[0] = 0
    expansion_offset = expansion_offset.cumsum(dim=0)
    expansion_offset = torch.repeat_interleave(expansion_offset, counts)
    expansion_offset = torch.repeat_interleave(expansion_offset, expansion)
    # Local position (within its group) of the first partner of row i:
    # group size - later rows = local position of i + 1.
    off_start = torch.repeat_interleave(torch.repeat_interleave(counts, counts) - expansion, expansion)
    # Right operand: global row index of every partner j > i. The rolled
    # cumsum is an exclusive prefix sum because the last row always has
    # expansion 0.
    access = torch.arange(expansion.sum(), device=data.device)
    access = access - torch.repeat_interleave(expansion.roll(1).cumsum(dim=0), expansion) + off_start + expansion_offset
    result = op(expanded, data[access.to(data.device)])
    return result, torch.repeat_interleave(indices, expansion, dim=0)
示例3
def __patched_conv_ops(op, x, y, *args, **kwargs):
    """Run the torch convolution ``op`` on int64 tensors through float64.

    Both operands are encoded with ``CUDALongTensor.__encode_as_fp64``. All
    component pairings are convolved in one grouped call (``groups=9``), and
    the result is recombined by ``CUDALongTensor.__decode_as_int64``.

    NOTE(review): the factors 3 and 9 assume the encoding stacks 3 float64
    components along dim 0 — confirm against ``__encode_as_fp64``.

    Args:
        op: name of a function in the ``torch`` namespace, looked up with
            ``getattr(torch, op)``.
        x: input, unpacked as ``(batch, channels, *spatial)``.
        y: weight, unpacked as ``(c_out, c_in, *kernel)``.
        *args, **kwargs: forwarded to the torch op. ``groups`` is set here,
            so passing it again raises ``TypeError``.
    """
    x_encoded = CUDALongTensor.__encode_as_fp64(x).data
    y_encoded = CUDALongTensor.__encode_as_fp64(y).data
    # x's components are tiled (A B C A B C A B C) and y's repeated in place
    # (A A A B B B C C C), so slot s pairs x-part s % 3 with y-part s // 3:
    # every one of the 9 combinations appears once.
    repeat_idx = [1] * (x_encoded.dim() - 1)
    x_enc_span = x_encoded.repeat(3, *repeat_idx)
    y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)
    bs, c, *img = x.size()
    c_out, c_in, *ks = y.size()
    # Fold the 9 pairings into the channel dimension; groups=9 then convolves
    # each pairing independently in a single call.
    x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
    y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)
    # Output channels per group: c_out for conv1d/conv2d, otherwise the second
    # weight dimension (presumably the conv_transpose* ops, whose weights are
    # laid out (in, out, ...)).
    # NOTE(review): "conv3d" would take the else branch — confirm it is never
    # routed here.
    c_z = c_out if op in ["conv1d", "conv2d"] else c_in
    z_encoded = getattr(torch, op)(
        x_enc_span, y_enc_span, *args, **kwargs, groups=9
    )
    # Unfold the 9 groups and move them back to dim 0 for decoding.
    z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
        0, 1
    )
    return CUDALongTensor.__decode_as_int64(z_encoded)
示例4
def get_tiled_batch(self, num_tiles: int):
    """Repeat every row of the float features ``num_tiles`` times.

    The tiled feature has shape ``(batch_size * num_tiles, feature_dim)`` and,
    for every ``i`` in ``range(batch_size)``,
    ``tiled_feature[i * num_tiles:(i + 1) * num_tiles]`` equals ``feat[i]``.

    Args:
        num_tiles: number of consecutive copies of each row.

    Returns:
        A new ``FeatureData`` holding only the tiled float features.

    Raises:
        AssertionError: if the batch does not hold float features only, or the
            float features are not 2-D.
    """
    assert (
        self.has_float_features_only
    ), f"only works for float features now: {self}"
    feat = self.float_features
    assert (
        len(feat.shape) == 2
    ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
    # repeat_interleave (not repeat) keeps the copies of a row adjacent.
    # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
    tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
    return FeatureData(float_features=tiled_feat)
示例5
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
    """Length-regulate encoder outputs by their predicted durations.

    Frame ``encoder_output[i, t]`` is repeated
    ``round(duration_predictor_output[i, t] * alpha)`` times. The expanded
    sequences are then right-padded with zeros into one batch tensor.

    Args:
        encoder_output: tensor indexed ``[batch, time, ...]``.
        duration_predictor_output: per-frame durations indexed
            ``[batch, time]``.
        alpha: speed factor multiplied into every duration.
        mel_max_length: if truthy, both outputs are cropped to this many
            frames.

    Returns:
        ``(output, dec_pos)``: the padded expanded sequences, and the 1-based
        position of every frame (0 marks padding) as int64 on the same device
        as ``output``.
    """
    output = []
    dec_pos = []
    for enc, durations in zip(encoder_output, duration_predictor_output):
        repeats = torch.round(durations.float() * alpha).long()
        expanded = torch.repeat_interleave(enc, repeats, dim=0)
        output.append(expanded)
        # Positions 1..length are built directly on the target device, with
        # no NumPy round-trip and no host-to-device copy.
        dec_pos.append(torch.arange(1, expanded.shape[0] + 1, device=expanded.device))
    output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
    dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)
    if mel_max_length:
        output = output[:, :mel_max_length]
        dec_pos = dec_pos[:, :mel_max_length]
    return output, dec_pos
示例6
def test_instance_norm():
    # Ten graphs with ten nodes each: batch = [0]*10 + [1]*10 + ... + [9]*10.
    nodes_per_graph = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(nodes_per_graph)

    def close(a, b):
        return torch.allclose(a, b, atol=1e-6)

    norm = InstanceNorm(16)
    assert repr(norm) == (
        'InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
        'track_running_stats=False)')
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    norm = InstanceNorm(16, affine=True, track_running_stats=True)
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    # Should behave equally to `BatchNorm` for mini-batches of size 1.
    x = torch.randn(100, 16)
    plain_in = InstanceNorm(16, affine=False, track_running_stats=False)
    plain_bn = BatchNorm(16, affine=False, track_running_stats=False)
    assert close(plain_in(x), plain_bn(x))

    norm1 = InstanceNorm(16, affine=False, track_running_stats=True)
    norm2 = BatchNorm(16, affine=False, track_running_stats=True)
    # Two training passes: outputs and running statistics must match after
    # each update.
    for _ in range(2):
        assert close(norm1(x), norm2(x))
        assert close(norm1.running_mean, norm2.running_mean)
        assert close(norm1.running_var, norm2.running_var)

    norm1.eval()
    norm2.eval()
    assert close(norm1(x), norm2(x))
示例7
def test_graph_size_norm():
    # Batch vector for 10 graphs of 10 nodes each (100 nodes in total).
    sizes = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(sizes)
    norm = GraphSizeNorm()
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)
示例8
def forward(self, seq_value_len_list):
    """Pool a padded sequence of embeddings into one vector per sample.

    Args:
        seq_value_len_list: ``(embeddings, mask)`` when
            ``self.supports_masking`` is true, otherwise
            ``(embeddings, lengths)``. Embeddings are ``[B, T, E]``, the mask
            is ``[B, T]`` (it is summed over its last dim and unsqueezed to
            3-D), and lengths are ``[B, 1]``.

    Returns:
        ``[B, 1, E]`` tensor. With ``self.mode == 'max'`` it is the masked max
        over time; otherwise it is the masked sum, divided by the sequence
        length (plus ``self.eps``) when ``self.mode == 'mean'``.
    """
    if self.supports_masking:
        uiseq_embed_list, mask = seq_value_len_list  # [B, T, E], [B, T]
        mask = mask.float()
        user_behavior_length = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
        mask = mask.unsqueeze(2)  # [B, T, 1]
    else:
        uiseq_embed_list, user_behavior_length = seq_value_len_list  # [B, T, E], [B, 1]
        mask = self._sequence_mask(user_behavior_length, maxlen=uiseq_embed_list.shape[1],
                                   dtype=torch.float32)  # [B, 1, maxlen]
        mask = torch.transpose(mask, 1, 2)  # [B, maxlen, 1]
    # The [B, T, 1] mask broadcasts across the embedding dimension, so no
    # [B, T, E] copy needs to be materialised.
    if self.mode == 'max':
        # Push padded steps to a huge negative value so they never win the max.
        hist = uiseq_embed_list - (1 - mask) * 1e9
        return torch.max(hist, dim=1, keepdim=True)[0]
    hist = torch.sum(uiseq_embed_list * mask, dim=1, keepdim=False)  # [B, E]
    if self.mode == 'mean':
        hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)
    return torch.unsqueeze(hist, dim=1)
示例9
def repeat(input, repeats, dim):
    """Repeat each slice of ``input`` along ``dim`` ``repeats`` times.

    Equivalent to ``torch.repeat_interleave(input, repeats, dim)`` for an
    integer ``repeats``: copies of the same slice are adjacent. Kept for
    PyTorch versions before 1.1, which lack ``repeat_interleave``.
    ``repeats == 0`` yields a size-0 extent along ``dim``.

    Args:
        input: tensor to repeat.
        repeats: non-negative number of copies of every slice.
        dim: dimension to repeat along; negative values count from the end.
    """
    if dim < 0:
        dim += input.dim()
    # Insert a new axis right after ``dim``, tile it ``repeats`` times, then
    # merge it back into ``dim``.
    tile = [1] * (input.dim() + 1)
    tile[dim + 1] = repeats
    return input.unsqueeze(dim + 1).repeat(*tile).flatten(dim, dim + 1)
示例10
def sample(self, data):
    """Draw a conditioned latent sample together with fresh local noise.

    Args:
        data: ``(support, values)`` pair; only ``support`` is used.

    Returns:
        ``((support, sample), (mean, logvar))``. ``sample`` has
        ``support.size(0) * self.size`` rows: the reparameterised latent of
        each support item (repeated ``self.size`` times, adjacently) followed
        by 16 standard-normal columns.
    """
    support, _values = data
    mean, logvar = self.condition(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    # The noise must share the latent's device and dtype, or torch.cat fails
    # for GPU latents.
    local_samples = torch.randn(
        support.size(0) * self.size, 16,
        device=latent_sample.device, dtype=latent_sample.dtype,
    )
    sample = torch.cat((latent_sample, local_samples), dim=1)
    return (support, sample), (mean, logvar)
示例11
def forward(self, data):
    # Encode the support set and draw one reparameterised latent per item.
    # Each latent is paired with ``self.size`` consecutive flattened 28x28
    # value images, and the result is scored by ``self.verdict``.
    support, values = data
    mu, log_var = self.encoder(support)
    latent = Normal(mu, torch.exp(0.5 * log_var)).rsample()
    latent = latent.repeat_interleave(self.size, dim=0)
    flat_values = values.view(-1, 28 * 28)
    return self.verdict(torch.cat((flat_values, latent), dim=1))
示例12
def forward(self, image, condition):
    # Score 3x64x64 images against a sampled condition embedding. Every
    # condition row is shared by 5 consecutive images.
    features = self.input_process(self.input(image.view(-1, 3, 64, 64)))
    mean, logvar = self.condition(condition)
    # Reparameterisation trick: mean + eps * std.
    std = torch.exp(0.5 * logvar)
    latent = mean + torch.randn_like(mean) * std
    cond = torch.repeat_interleave(self.postprocess(latent), 5, dim=0)
    merged = torch.cat((features, cond), dim=1)
    return self.combine(merged), (mean, logvar)
示例13
def forward(self, image, condition):
    # Score flattened 28x28 images against a sampled condition embedding.
    # Every condition row is shared by 5 consecutive images.
    flat = image.view(-1, 28 * 28)
    features = self.input_process(self.input(flat))
    mean, logvar = self.condition(condition)
    noise = torch.randn_like(mean)
    latent = mean + noise * torch.exp(0.5 * logvar)
    cond = self.postprocess(latent).repeat_interleave(5, dim=0)
    return self.combine(torch.cat((features, cond), dim=1)), (mean, logvar)
示例14
def repack(data, indices, target_indices):
    """Scatter grouped rows of ``data`` into a wider grouped layout.

    The rows of ``data`` are grouped by ``indices``. ``target_indices``
    describes the output layout, where every group may have more slots. Each
    group's rows are copied to the start of that group's slots, and the
    remaining slots are zero.

    NOTE(review): assumes both index tensors are sorted, contain the same
    set of group ids, and every target group is at least as large as its
    source group — confirm with callers.

    Args:
        data: tensor whose first dimension is indexed by ``indices``.
        indices: 1-D group id of every row of ``data``.
        target_indices: 1-D group id of every row of the output.

    Returns:
        ``(out, target_indices)``, where ``out`` has
        ``target_indices.size(0)`` rows and the trailing shape and dtype of
        ``data``.
    """
    out = torch.zeros(
        target_indices.size(0), *data.shape[1:],
        dtype=data.dtype, device=data.device
    )
    if indices.numel() == 0:
        # Nothing to copy; the offset computation below needs >= 1 group.
        return out, target_indices
    _, lengths = indices.unique(return_counts=True)
    _, target_lengths = target_indices.unique(return_counts=True)
    # Extra slots gained by each group. Their exclusive prefix sum is how far
    # every row of a group shifts in the output.
    offset = target_lengths - lengths
    offset = offset.roll(1, 0)
    offset[0] = 0
    offset = torch.repeat_interleave(offset.cumsum(dim=0), lengths, dim=0)
    index = offset + torch.arange(len(indices), device=offset.device)
    out[index] = data
    # Return the repacked tensor, not the input ``data``.
    return out, target_indices
示例15
def pairwise(op, data, indices, padding_value=0):
    """Evaluate ``op`` on every ordered pair of rows within each group.

    ``data`` is padded per group with ``pad``, and ``op`` is applied to the
    broadcast pair grid. The valid (non-padded) entries are then gathered
    into a flat result.

    NOTE(review): ``op_result[batch, first, second]`` assumes ``op`` returns
    the pair grid in dims 1 and 2 (e.g. it reduces the feature dimension) —
    confirm with callers.

    Args:
        op: binary callable, called once as ``op(padded, reference)``.
        data: rows grouped by ``indices``.
        indices: group id of every row of ``data``.
        padding_value: fill value ``pad`` uses for missing rows.

    Returns:
        ``(result, batch_indices, first_indices, second_indices, access)``:
        the flat valid pair results, the group / first / second index of each
        entry, and ``access = (access_batch, counts)``. The flat position of
        pair (idx, idy, idz) is ``access_batch[idx] + counts[idx] * idy + idz``.
    """
    padded, _, _, counts = pad(data, indices, value=padding_value)
    padded = padded.transpose(1, 2)
    reference = padded.unsqueeze(-1)
    padded = padded.unsqueeze(-2)
    op_result = op(padded, reference)
    # Build every index tensor on counts' device so repeat_interleave and the
    # gather below never mix CPU and GPU tensors.
    device = counts.device
    # batch indices into pairwise tensor:
    batch_indices = torch.arange(counts.size(0), device=device)
    batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)
    # first dimension indices:
    first_offset = counts.roll(1)
    first_offset[0] = 0
    first_offset = torch.cumsum(first_offset, dim=0)
    first_offset = torch.repeat_interleave(first_offset, counts)
    first_indices = torch.arange(counts.sum(), device=device) - first_offset
    first_indices = torch.repeat_interleave(
        first_indices,
        torch.repeat_interleave(counts, counts)
    )
    # second dimension indices:
    second_offset = torch.repeat_interleave(counts, counts).roll(1)
    second_offset[0] = 0
    second_offset = torch.cumsum(second_offset, dim=0)
    second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
    second_indices = torch.arange((counts ** 2).sum(), device=device) - second_offset
    # extract tensor from padded result using indices:
    result = op_result[batch_indices, first_indices, second_indices]
    # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
    access_batch = (counts ** 2).roll(1)
    access_batch[0] = 0
    access_batch = torch.cumsum(access_batch, dim=0)
    access_first = counts
    access = (access_batch, access_first)
    return result, batch_indices, first_indices, second_indices, access
示例16
def __init__(self, indices):
    """Build a scatter structure connecting every pair of rows in a group.

    Every row is connected to every row of its own group (itself included).

    Args:
        indices: group id of every row. NOTE(review): assumed sorted, so
            each group is a contiguous run of rows — confirm with callers.
    """
    unique, counts = indices.unique(return_counts=True)
    # One entry per (row, partner) pair: row r is repeated once per member
    # of its group.
    structure_indices = torch.arange(counts.sum(), device=indices.device)
    structure_indices = torch.repeat_interleave(
        structure_indices, torch.repeat_interleave(
            counts, counts
        )
    )
    # prepare offsets of connections:
    # repeated_counts[r] = size of row r's group.
    repeated_counts = torch.repeat_interleave(counts, counts)
    # Exclusive prefix sum: position where row r's block of pairs starts.
    other_counts = repeated_counts.roll(1)
    other_counts[0] = 0
    other_counts = other_counts.cumsum(dim=0)
    offset_factors = other_counts
    offset = torch.repeat_interleave(offset_factors, repeated_counts)
    # Global row index at which each pair's group starts.
    base = counts.roll(1)
    base[0] = 0
    base = base.cumsum(dim=0)
    base = torch.repeat_interleave(torch.repeat_interleave(base, counts), repeated_counts)
    # (pair position - start of its row's block) is the partner's local index
    # in the group; adding the group start gives its global row index.
    structure_connections = torch.arange((counts * counts).sum(), device=indices.device)
    structure_connections = structure_connections - offset + base
    super(FullyConnectedScatter, self).__init__(
        0, 0,
        structure_indices,
        structure_connections
    )
示例17
def __init__(self, batch, width):
    # Connection table with batch * width rows, each equal to
    # [0, 1, ..., width - 1]: every row connects to all ``width`` positions.
    row = torch.arange(width)
    structure_connections = row.repeat(batch * width, 1)
    super(FullyConnectedConstant, self).__init__(
        0, 0,
        structure_connections
    )
示例18
def matmul(x, y, *args, **kwargs):
    """Matrix-multiply int64 tensors through their float64 encodings.

    ``x`` and ``y`` are encoded with ``CUDALongTensor.__encode_as_fp64``.
    Every pairing of their components is multiplied in one batched
    ``torch.matmul``, and the products are recombined with
    ``CUDALongTensor.__decode_as_int64``. 1-D operands are promoted to 2-D,
    as ``torch.matmul`` does, and the added dimension is removed from the
    result.

    NOTE(review): the factor 3 and the pairing scheme assume the encoding
    stacks 3 float64 components along dim 0 — confirm against
    ``__encode_as_fp64``.

    Args:
        x, y: operands following ``torch.matmul`` shape rules.
        *args, **kwargs: forwarded to ``torch.matmul``.
    """
    # Promote 1-D operands to 2-D: x becomes a row (1, n), y a column (n, 1).
    remove_x, remove_y = False, False
    if x.dim() == 1:
        x = x.view(1, x.shape[0])
        remove_x = True
    if y.dim() == 1:
        y = y.view(y.shape[0], 1)
        remove_y = True
    x_encoded = CUDALongTensor.__encode_as_fp64(x).data
    y_encoded = CUDALongTensor.__encode_as_fp64(y).data
    # Span x and y for cross multiplication
    # x's components are tiled (A B C A B C ...) and y's repeated in place
    # (A A A B B B ...), so the 9 slots along dim 0 cover every pairing.
    repeat_idx = [1] * (x_encoded.dim() - 1)
    x_enc_span = x_encoded.repeat(3, *repeat_idx)
    y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)
    # Broadcasting
    # torch.matmul aligns batch dims from the right. Insert singleton dims
    # right after the leading component axis of the lower-rank operand so
    # both component axes stay aligned at dim 0.
    for _ in range(abs(x_enc_span.ndim - y_enc_span.ndim)):
        if x_enc_span.ndim > y_enc_span.ndim:
            y_enc_span.unsqueeze_(1)
        else:
            x_enc_span.unsqueeze_(1)
    z_encoded = torch.matmul(x_enc_span, y_enc_span, *args, **kwargs)
    # Drop the dimensions added for 1-D inputs.
    if remove_x:
        z_encoded.squeeze_(-2)
    if remove_y:
        z_encoded.squeeze_(-1)
    return CUDALongTensor.__decode_as_int64(z_encoded)
示例19
def extract_video(self, img):
    # Turn one image into ``n_clips`` identical clips of ``clip_len`` frames.
    # If the transformed image is (C, H, W), the result is
    # (n_clips, C, clip_len, H, W).
    frame = self.transform(img)
    clip = frame.unsqueeze(1).repeat_interleave(self.clip_len, dim=1)
    return clip.unsqueeze(0).repeat_interleave(self.n_clips, dim=0)
示例20
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
    """Expected Reciprocal Rank (ERR) of (presumably ideally sorted) label lists.

    Uses the cascade model. The satisfaction probability at rank ``r`` is
    ``R_r = (2**g_r - 1) / 2**g_max`` and
    ``ERR@k = sum_r (1 / r) * R_r * prod_{i < r} (1 - R_i)``.
    The rows are scored in the given order; they are not sorted here.

    Args:
        batch_sorted_labels: ``[batch, n]`` relevance labels, with ``n > k``.
        k: cutoff; only the first ``k`` positions are scored.
        gpu: unused; the computation runs on the device of
            ``batch_sorted_labels``. Kept for backward compatibility.
        point: if true return ERR@k with shape ``[batch]``; otherwise
            ERR@1..ERR@k with shape ``[batch, k]``.
    """
    assert batch_sorted_labels.size(1) > k
    # g_max per list, kept as [batch, 1] so it broadcasts over positions.
    batch_max = torch.max(batch_sorted_labels, dim=1, keepdim=True).values
    batch_labels = batch_sorted_labels[:, 0:k]
    batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)
    batch_unsatis_pros = 1.0 - batch_satis_pros
    batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)
    positions = torch.arange(
        1, k + 1, dtype=batch_satis_pros.dtype, device=batch_sorted_labels.device
    )
    batch_expt_ranks = (1.0 / positions).view(1, -1)  # broadcasts over the batch
    # Probability that the user reaches rank r: 1 for the first rank, then
    # the product of (1 - R_i) over all earlier ranks.
    cascad_unsatis_pros = torch.ones_like(batch_satis_pros)
    cascad_unsatis_pros[:, 1:k] = batch_cum_unsatis_pros[:, 0:k-1]
    expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions
    if point:
        return torch.sum(expt_satis_ranks, dim=1)
    return torch.cumsum(expt_satis_ranks, dim=1)
示例21
def __init__(
    self, data_encoder=None, data_transform=None, lbl_transform=None, repeats=1
):
    # Optional hooks, stored as given.
    self.data_encoder, self.data_transform, self.lbl_transform = (
        data_encoder, data_transform, lbl_transform
    )
    # All four 2-bit binary input points; each column is duplicated
    # ``repeats`` times side by side.
    points = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.float)
    self.data = points.repeat_interleave(int(repeats), dim=1)
示例22
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Enumerate, for every atom, all pairs of its neighbours.

    Input: indices for pairs of atoms that are close to each other.
    Each pair appears only once, i.e. only one of the pairs (1, 2) and
    (2, 1) exists. NOTE(review): the code reads ``atom_index12.shape[1]`` as
    the number of pairs, so the input is presumably ``(2, num_pairs)``.

    Output: indices for all central atoms and their pairs of neighbors. For
    example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
    (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
    central atom 0, 1, 2, 3, 4 and for central atom 0, its pairs of neighbors
    are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)

    Returns:
        ``(central_atom_index, pair_index12, sign12)``:
        ``central_atom_index`` is the central atom of each neighbour pair.
        ``pair_index12`` (2 rows) holds, for each of the two neighbours, the
        index of the input pair linking it to the central atom.
        ``sign12`` is +1 where the central atom is the first atom of that
        input pair and -1 where it is the second.
    """
    # convert representation from pair to central-others
    # One slot per pair endpoint: slot p < n is the first atom of pair p,
    # slot p >= n is the second atom of pair p - n (n = number of pairs).
    ai1 = atom_index12.view(-1)
    # Sorting brings all slots of the same atom together; rev_indices maps
    # every sorted position back to its slot in ai1.
    sorted_ai1, rev_indices = ai1.sort()
    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True)
    # compute central_atom_index
    # An atom with c neighbours has c * (c - 1) / 2 neighbour pairs.
    pair_sizes = counts * (counts - 1) // 2
    # Group id repeated once per neighbour pair of that group.
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
    # do local combinations within unique key, assuming sorted
    # Strict lower-triangle indices of the largest group, copied for every
    # group. tril_indices lists them row by row, so the first c * (c - 1) / 2
    # entries are exactly the pairs inside [0, c); the mask keeps those.
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]  # number of groups here; reused below as number of pairs
    intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    # Shift local indices by the start of their group in sorted order
    # (cumsum_from_zero is presumably an exclusive cumulative sum — confirm).
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)
    # unsort result from last part
    local_index12 = rev_indices[sorted_local_index12]
    # compute mapping between representation of central-other to pair
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12
示例23
def sample_weights(self) -> Tensor:
    # Lazily build and cache one weight per sample. A query contributing
    # ``s`` samples (column 2 of ``self.queries``) gives each of them 1 / s.
    if self._sample_weights is not None:
        return self._sample_weights
    counts = self.queries[:, 2]
    per_query = counts.to(dtype=torch.float).reciprocal()
    self._sample_weights = per_query.repeat_interleave(counts.to(dtype=torch.long))
    return self._sample_weights
示例24
def select_slate(self, action: torch.Tensor):
    """Gather the documents chosen by ``action`` into a new ``DocList``.

    Args:
        action: 2-D integer tensor; ``action[i, j]`` is the index (along the
            second dimension of this list's tensors) of the document placed
            in slot ``j`` of row ``i``.

    Returns:
        ``DocList`` of the selected ``float_features``, ``mask`` and
        ``value`` entries, indexed ``[row, slot, ...]``.

    Raises:
        AssertionError: if any selected document has a false mask entry.
    """
    # Row index of shape [batch, 1], on the same device as ``action``.
    # Advanced indexing broadcasts it against ``action`` ([batch, slate]),
    # so it does not need to be repeated explicitly.
    row_idx = torch.arange(action.shape[0], device=action.device).unsqueeze(1)
    mask = self.mask[row_idx, action]
    # Every selected document must be a valid (unmasked) one.
    assert mask.to(torch.bool).all()
    float_features = self.float_features[row_idx, action]
    value = self.value[row_idx, action]
    return DocList(float_features, mask, value)
示例25
def _generate_text_rep(text, dur):
    # Expand each token sequence by its per-token durations, then combine
    # the per-item results with Ops.merge.
    expanded = [torch.repeat_interleave(tokens, durations) for tokens, durations in zip(text, dur)]
    return Ops.merge(expanded)
示例26
def forward(self, inputs):
    # Local features: the head runs on a k-NN graph built from the first three
    # channels, followed by n_blocks - 1 backbone blocks, each fed the
    # previous block's output.
    feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
    for idx in range(self.n_blocks - 1):
        prev = feats[-1]
        feats.append(self.backbone[idx](prev))
    fusion = self.fusion_block(torch.cat(feats, 1))
    # Global descriptor: channel-wise max and mean, copied back to every
    # position along dim 2.
    global_max = F.adaptive_max_pool2d(fusion, 1)
    global_avg = F.adaptive_avg_pool2d(fusion, 1)
    pooled = torch.cat((global_max, global_avg), dim=1)
    pooled = pooled.repeat_interleave(fusion.shape[2], dim=2)
    logits = self.prediction(torch.cat((pooled, fusion), dim=1)).squeeze(-1)
    return F.log_softmax(logits, dim=1)
示例27
def forward(self, inputs):
    # Stack the head output and n_blocks - 1 backbone outputs along channels.
    feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
    for idx in range(self.n_blocks - 1):
        feats.append(self.backbone[idx](feats[-1]))
    feats = torch.cat(feats, dim=1)
    # Max-pool the fused features over the full extent of dims 2 and 3,
    # then copy the global vector back along dim 2.
    kernel = [feats.shape[2], feats.shape[3]]
    fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=kernel)
    fusion = fusion.repeat_interleave(feats.shape[2], dim=2)
    merged = torch.cat((fusion, feats), dim=1)
    return self.prediction(merged).squeeze(-1)
示例28
def forward(self, data):
    """Per-point predictions from local features plus a per-graph descriptor.

    Args:
        data: batch object exposing ``pos`` and ``x`` (concatenated as the
            point features) and ``batch`` (graph id of every point).

    Returns:
        Output of ``self.prediction`` on the per-point features, each
        concatenated with the max-pooled descriptor of its own graph.
    """
    corr, color, batch = data.pos, data.x, data.batch
    x = torch.cat((corr, color), dim=1)
    feats = [self.head(x, self.knn(x[:, 0:3], batch))]
    for i in range(self.n_blocks-1):
        feats.append(self.backbone[i](feats[-1], batch)[0])
    feats = torch.cat(feats, dim=1)
    # One max-pooled descriptor per graph id.
    fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
    # Index with the batch vector to give every point its own graph's
    # descriptor. A repeat_interleave by N // num_graphs is only correct when
    # all graphs have the same number of contiguous points.
    fusion = fusion[batch]
    return self.prediction(torch.cat((fusion, feats), dim=1))
示例29
def train(self, replay_buffer, iterations, batch_size=100):
    """Run ``iterations`` update steps of the VAE, critics and actor.

    Each step samples one batch from ``replay_buffer`` and, in order:
    1. fits the VAE to reconstruct ``action`` from ``(state, action)``
       (MSE + 0.5 * KL);
    2. regresses both critics toward a target built from 10 VAE-decoded,
       actor-perturbed actions per next state;
    3. updates the actor to maximise ``self.critic.q1``;
    4. soft-updates both target networks with rate ``self.tau``.
    The structure matches Batch-Constrained Q-learning (BCQ).

    Args:
        replay_buffer: object whose ``sample(batch_size)`` returns
            ``(state, action, next_state, reward, not_done)``.
        iterations: number of update steps.
        batch_size: transitions sampled per step. NOTE(review): the
            ``reshape(batch_size, -1)`` below assumes ``sample`` returns
            exactly this many rows.
    """
    for it in range(iterations):
        # Sample replay buffer / batch
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        # KL divergence of N(mean, std^2) from N(0, 1), averaged over elements.
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss
        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        # Critic Training
        with torch.no_grad():
            # Duplicate next state 10 times
            # (repeat_interleave keeps the 10 copies of each state adjacent,
            # which the reshape(batch_size, -1) below relies on).
            next_state = torch.repeat_interleave(next_state, 10, 0)
            # Compute value of perturbed actions sampled from the VAE
            target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))
            # Soft Clipped Double Q-learning
            # lmbda weights the pessimistic min against the max of the critics.
            target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
            # Take max over each action sampled from the VAE
            target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)
            # One-step Bellman target; not_done masks out the bootstrap term.
            target_Q = reward + not_done * self.discount * target_Q
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Perturbation Model / Action Training
        sampled_actions = self.vae.decode(state)
        perturbed_actions = self.actor(state, sampled_actions)
        # Update through DPG
        actor_loss = -self.critic.q1(state, perturbed_actions).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # Update Target Networks
        # Polyak averaging: target <- tau * online + (1 - tau) * target.
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
示例30
def __call__(self, img):
    """
    Keep only the largest connected component (as computed by
    ``get_largest_connected_component_mask``) of the labels in
    ``self.applied_labels``; other voxels of those labels are set to 0.

    Two input layouts are handled: a single channel holding label values,
    and one-hot data with one binary channel per label. With
    ``self.independent`` each label is processed on its own; otherwise the
    union of all applied labels is treated as one foreground.

    Note: ``img`` is modified in place. The single-channel path writes
    through a squeezed view; the one-hot path writes into ``img`` directly.

    Args:
        img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
    Returns:
        A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
    """
    channel_dim = 1
    if img.shape[channel_dim] == 1:
        # Single channel: voxel values are label ids.
        img = torch.squeeze(img, dim=channel_dim)
        if self.independent:
            for i in self.applied_labels:
                foreground = (img == i).type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                # Zero voxels of label i that lie outside the kept component.
                img[foreground != mask] = 0
        else:
            # Union of all applied labels as one foreground.
            foreground = torch.zeros_like(img)
            for i in self.applied_labels:
                foreground += (img == i).type(torch.uint8)
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            img[foreground != mask] = 0
        output = torch.unsqueeze(img, dim=channel_dim)
    else:
        # one-hot data is assumed to have binary value in each channel
        if self.independent:
            for i in self.applied_labels:
                foreground = img[:, i, ...].type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                img[:, i, ...][foreground != mask] = 0
        else:
            applied_img = img[:, self.applied_labels, ...].type(torch.uint8)
            foreground = torch.any(applied_img, dim=channel_dim)
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            # Copy the per-voxel "remove" mask to every applied channel.
            background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim)
            background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim)
            applied_img[background_mask] = 0
            # Write the cleaned channels back using the input's tensor type.
            img[:, self.applied_labels, ...] = applied_img.type(img.type())
        output = img
    return output