Python源码示例:torch.remainder()
示例1
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例2
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Build the column positions directly on src_tokens' device. The old
    # CPU tensor + type_as round trip copied every call and may not land on
    # the same GPU index as src_tokens.
    positions = (
        torch.arange(max_len, device=src_tokens.device)
        .type_as(src_tokens)
        .expand_as(src_tokens)
    )
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例3
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例4
def mod_divide(self, forget_radix, mask):
    """Pop the low ``forget_radix`` bits off the buffer and return them.

    Returns ``curr_buffer mod 2**forget_radix`` as int32. Where ``mask`` is
    set, ``curr_buffer`` is divided by ``2**forget_radix``; elsewhere it is
    left unchanged (assumes ``mask`` holds 0/1 values — confirm). If the
    overflow bitmap records an overflow for this step, ``curr_buffer`` is
    restored from the last slice of ``past_buffers`` (along dim 2), and that
    slice is dropped while more than one remains.
    """
    # Only called in reverse process.
    mask = mask.long()
    # Step back one position in the overflow bitmap.
    self.counter -= 1
    buf_mod = torch.remainder(self.curr_buffer, 2**forget_radix).int()
    # NOTE(review): on current PyTorch `/` between integer tensors is true
    # division and yields a float tensor; this line presumably relied on older
    # integer-division semantics — confirm the dtype of curr_buffer here.
    self.curr_buffer = mask*(self.curr_buffer/(2**forget_radix)) + (1-mask)*self.curr_buffer
    # Test bit (counter % 8) of element (counter // 8) of the overflow bitmap.
    overflowed = self.overflow_detect.__and__(2**(self.counter % 8))[self.counter // 8]
    if overflowed:
        # Restore the buffer that was saved when the overflow happened.
        self.curr_buffer = self.past_buffers[:,:,-1]
        if self.past_buffers.size(2) > 1:
            self.past_buffers = self.past_buffers[:,:,:-1]
    return buf_mod
###############################################################################
# Multiply/divide fixed point numbers with buffer
###############################################################################
示例5
def forward(ctx, h, z, buf, mask, slice_dim=0):
    """Fixed-point multiply of ``h`` by ``z / 2**forget_radix`` with an info buffer.

    ``forget_radix``, ``sign_bit`` and ``negative_bits`` are names from outside
    this block (presumably module-level constants — confirm). ``mask`` is
    presumably a 0/1 integer tensor selecting which elements are updated.
    When ``buf`` is given, the bits shifted out of ``h`` are pushed into it,
    and the remainder of ``buf`` modulo ``z`` is folded back into ``h``.

    Returns:
        The updated ``h``.
    """
    # Saved for the backward pass (presumably an autograd.Function staticmethod).
    ctx.save_for_backward(h, z, mask)
    # Shift buffer left, enlarging if needed, then store modulus of h in buffer.
    if buf is not None:
        h_mod = torch.remainder(h[:, slice_dim:], 2**forget_radix)
        buf.overflow_mul(2**forget_radix, mask[:, slice_dim:])
        buf.add(h_mod, mask[:, slice_dim:])
    # Multiply h by z/(2**forget_radix).
    # Have to do extra work in case h is negative.
    # Assuming sign_bit isolates the (negative) sign bit, sign_bits is 0 or
    # negative, so one_bits is negative_bits for negative h and 0 otherwise.
    sign_bits = h.__and__(sign_bit)
    one_bits = negative_bits * -1 * torch.clamp(sign_bits, min=-1)
    h = h.__rshift__(forget_radix * mask)
    # Re-set the high bits for negative entries after the shift (hedged: the
    # exact intent depends on the value of negative_bits).
    h = h.__or__(one_bits * mask)
    h = mask*h*z + (1-mask)*h
    # Store modulus of buffer in h then divide buffer by z.
    if buf is not None:
        buf_mod = buf.mod(z[:,slice_dim:])
        h[:,slice_dim:] = h[:,slice_dim:] + buf_mod*mask[:, slice_dim:]
        buf.div(z[:,slice_dim:], mask[:, slice_dim:])
    return h
示例6
def forward(self, input, offsets=None, per_sample_weights=None):
    """Look up quotient/remainder embeddings and combine them.

    Each index ``i`` is split into ``i // num_collisions`` (row of
    ``weight_q``) and ``i % num_collisions`` (row of ``weight_r``). Both halves
    are reduced with ``F.embedding_bag`` using the same bag layout and options.

    Args:
        input: integer index tensor in any layout ``F.embedding_bag`` accepts.
        offsets: optional bag offsets, forwarded to ``F.embedding_bag``.
        per_sample_weights: optional weights, forwarded to ``F.embedding_bag``.

    Returns:
        The concatenation ('concat'), sum ('add') or element-wise product
        ('mult') of the two bag embeddings.

    Raises:
        ValueError: if ``self.operation`` is not one of the three above.
    """
    # Exact integer floor division. The old `(input / n).long()` computes a
    # float32 quotient on current PyTorch, which is wrong for large indices.
    input_q = torch.div(input, self.num_collisions, rounding_mode='floor').long()
    input_r = torch.remainder(input, self.num_collisions).long()
    embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                              self.norm_type, self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                              self.norm_type, self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    if self.operation == 'concat':
        return torch.cat((embed_q, embed_r), dim=1)
    if self.operation == 'add':
        return embed_q + embed_r
    if self.operation == 'mult':
        return embed_q * embed_r
    raise ValueError(f"Unsupported operation: {self.operation!r}")
示例7
def forward(self, input, offsets=None, per_sample_weights=None):
    """Look up quotient/remainder embeddings and combine them.

    Each index ``i`` is split into ``i // num_collisions`` (row of
    ``weight_q``) and ``i % num_collisions`` (row of ``weight_r``). Both halves
    are reduced with ``F.embedding_bag`` using the same bag layout and options.

    Args:
        input: integer index tensor in any layout ``F.embedding_bag`` accepts.
        offsets: optional bag offsets, forwarded to ``F.embedding_bag``.
        per_sample_weights: optional weights, forwarded to ``F.embedding_bag``.

    Returns:
        The concatenation ('concat'), sum ('add') or element-wise product
        ('mult') of the two bag embeddings.

    Raises:
        ValueError: if ``self.operation`` is not one of the three above.
    """
    # Exact integer floor division. The old `(input / n).long()` computes a
    # float32 quotient on current PyTorch, which is wrong for large indices.
    input_q = torch.div(input, self.num_collisions, rounding_mode='floor').long()
    input_r = torch.remainder(input, self.num_collisions).long()
    embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                              self.norm_type, self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                              self.norm_type, self.scale_grad_by_freq, self.mode,
                              self.sparse, per_sample_weights)
    if self.operation == 'concat':
        return torch.cat((embed_q, embed_r), dim=1)
    if self.operation == 'add':
        return embed_q + embed_r
    if self.operation == 'mult':
        return embed_q * embed_r
    raise ValueError(f"Unsupported operation: {self.operation!r}")
示例8
def convert_padding_direction(src_tokens,
                              padding_idx,
                              right_to_left=False,
                              left_to_right=False):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例9
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例10
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids.
        src_lengths: accepted for interface compatibility; not used here.
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # `.any()` also copes with an empty batch, where `.max()` raises.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    # Without these guards, a batch already padded on the target side would be
    # rotated a second time and its tokens scrambled.
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例11
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids.
        src_lengths: accepted for interface compatibility; not used here.
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # `.any()` also copes with an empty batch, where `.max()` raises.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    # Without these guards, a batch already padded on the target side would be
    # rotated a second time and its tokens scrambled.
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例12
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Build the column positions directly on src_tokens' device. The old
    # CPU tensor + type_as round trip copied every call and may not land on
    # the same GPU index as src_tokens.
    positions = (
        torch.arange(max_len, device=src_tokens.device)
        .type_as(src_tokens)
        .expand_as(src_tokens)
    )
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例13
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move the padding of each row in a 2-D token batch to the other end.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set.
    ``left_to_right`` converts left-padded rows to right-padded rows;
    ``right_to_left`` converts right-padded rows to left-padded rows.
    Assumes the padding in each row is one contiguous run at one end — TODO
    confirm with callers.

    Args:
        src_tokens: 2-D tensor of token ids (the code indexes ``[:, 0]``,
            ``[:, -1]`` and ``size(1)``).
        padding_idx: token id that marks padding.
        right_to_left: convert right padding to left padding.
        left_to_right: convert left padding to right padding.

    Returns:
        A tensor of the same shape. ``src_tokens`` itself is returned when it
        has no padding or is already padded on the target side.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # Named `positions` so the builtin `range` is not shadowed.
    # buffered_arange(max_len) is presumably a cached arange(max_len) — confirm.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotating each row by its pad count moves the pad run to the other end;
    # torch.remainder keeps negative offsets in [0, max_len).
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
示例14
def loss(self, outputs, labels, **_):
    """Compute the criterion on logits modified around a "flipped" label.

    In training mode, each label is shifted by ``num_classes // 2`` (mod
    ``num_classes``). The logit at that flipped class is set to 0, and every
    other logit has 1 subtracted. The result is passed to ``self.criterion``
    together with the original labels. Outside training mode, a constant
    ``tensor([0.])`` is returned.

    Args:
        outputs: 2-D logits; dim 1 is indexed by class (via ``scatter_``).
        labels: class indices, 1-D or already shaped (N, 1).

    Returns:
        The criterion value, or ``torch.FloatTensor([0])`` when not training.
    """
    if self.model.training:
        labels_flip = labels + self.num_classes // 2
        labels_flip = torch.remainder(labels_flip, self.num_classes)
        if labels_flip.dim() == 1:
            labels_flip = labels_flip.unsqueeze(-1)
        # Allocate like `outputs` (same device and dtype) instead of a CPU
        # tensor copied with a hard-coded `.cuda()`. That copy failed on CPU
        # and could land on a different GPU than `outputs`.
        onehot = torch.zeros_like(outputs)
        onehot.scatter_(1, labels_flip, 1)
        onehot_invert = (onehot == 0).float()
        assert onehot_invert.size() == outputs.size()
        outputs = outputs * onehot_invert - onehot_invert
        return self.criterion(outputs, labels)
    return torch.FloatTensor([0])
示例15
def reflect_conj_concat(kern, dim):
    """Complete a Hermitian-symmetric kernel from one half of it.

    The half ``kern`` is scaled by [1, -1] broadcast along its axis 2
    (presumably the real/imaginary axis, i.e. conjugation — confirm layout).
    It is then reversed (index i -> -i mod n) along the trailing axes. Along
    the concatenation axis its first slice is replaced by zeros, and the
    result is appended to ``kern``.

    Args:
        kern (tensor): One half of a full, Hermitian-symmetric kernel.
        dim (int): The integer across which to apply Hermitian symmetry.

    Returns:
        tensor: The full FFT kernel after Hermitian-symmetric reflection.
    """
    axis = -1 - dim
    # Same axes as arange(|axis|) + axis, but as plain Python ints.
    reflect_axes = [axis + k for k in range(abs(axis))]

    # Zero slab one element thick along `axis`, matching kern elsewhere.
    slab_shape = list(kern.shape)
    slab_shape[axis] = 1
    zero_slab = torch.zeros(*slab_shape, dtype=kern.dtype, device=kern.device)

    # [1, -1] reshaped to (1, 1, 2, 1, ..., 1) so it broadcasts against kern.
    sign = torch.tensor([1, -1], dtype=kern.dtype, device=kern.device)[None, None]
    while sign.ndim < kern.ndim:
        sign = sign[..., None]

    mirrored = sign * kern
    for ax in reflect_axes:
        n = mirrored.shape[ax]
        order = torch.remainder(-torch.arange(n, device=kern.device), n)
        mirrored = mirrored.index_select(ax, order)

    tail = mirrored.narrow(axis, 1, mirrored.shape[axis] - 1)
    mirrored = torch.cat((zero_slab, tail), axis)
    return torch.cat((kern, mirrored), axis)
示例16
def hermitify(kern, dim):
    """Enforce Hermitian symmetry.

    This function takes an approximately Hermitian-symmetric kernel and
    enforces Hermitian symmetry. It builds a copy whose coordinates are
    reversed (index i -> -i mod n) along the last ``dim + 1`` axes and which
    is scaled by [1, -1] along axis 2 (the conjugation). It then averages
    that copy with the original.

    Args:
        kern (tensor): An approximately Hermitian-symmetric kernel.
        dim (int): The last imaging dimension.

    Returns:
        tensor: A Hermitian-symmetric kernel.
    """
    first_axis = kern.ndim - 1 - dim
    original = kern.clone()

    flipped = kern
    for ax in range(first_axis, kern.ndim):
        n = flipped.shape[ax]
        order = torch.remainder(-torch.arange(n, device=kern.device), n)
        flipped = flipped.index_select(ax, order)

    # [1, -1] reshaped to (1, 1, 2, 1, ..., 1) so it broadcasts against kern.
    sign = torch.tensor([1, -1], dtype=kern.dtype, device=kern.device)[None, None]
    while sign.ndim < flipped.ndim:
        sign = sign[..., None]

    return (original + sign * flipped) / 2
示例17
def mod(self, divisor):
    """Return ``curr_buffer`` modulo ``divisor`` (remainder takes the divisor's sign) as int32.

    ``divisor`` is cast to int64 before the remainder is taken.
    """
    divisor_long = divisor.long()
    remainder = torch.remainder(self.curr_buffer, divisor_long)
    return remainder.int()
示例18
def forward(ctx, h, z, buf, mask, slice_dim=0):
    """Fixed-point multiply of ``h`` by ``2**forget_radix / z`` using an info buffer.

    Where ``mask`` is set, the remainder of ``h`` modulo ``z`` is pushed into
    ``buf``. ``h`` is divided by ``z`` and scaled by ``2**forget_radix``, and
    the low bits popped from ``buf`` (via ``buf.mod_divide``) are OR-ed back
    in. ``forget_radix`` comes from outside this block (presumably a
    module-level constant — confirm). ``ctx`` is not used in this body.

    Note: the boolean indexing with ``h<0`` requires ``mask`` and ``z`` to have
    ``h``'s shape. The input ``h`` is modified in place by that assignment.

    Returns:
        The updated ``h``.
    """
    buf.mul(z[:,slice_dim:], mask[:, slice_dim:])
    h_mod = torch.remainder(h[:,slice_dim:], z[:,slice_dim:])
    buf.add(h_mod, mask[:, slice_dim:])
    # For negative entries subtract (z - 1) first, presumably so that a
    # truncating division below rounds toward -inf — confirm intent.
    h[h<0] = mask[h<0]*(h[h<0]-(z[h<0]-1)) + (1-mask[h<0])*h[h<0]
    # NOTE(review): on current PyTorch `/` between integer tensors is true
    # division and yields floats; this presumably relied on integer
    # division — confirm the dtypes of h and z.
    h = mask*(h / z) + (1-mask)*h
    h = mask*(h * (2**forget_radix)) + (1-mask)*h
    # Pop the low forget_radix bits from the buffer and OR them into h.
    buf_mod = buf.mod_divide(forget_radix, mask[:, slice_dim:])
    h[:,slice_dim:] = mask[:, slice_dim:]*(h[:,slice_dim:].__or__(buf_mod)) +\
        (1-mask[:, slice_dim:])*h[:,slice_dim:]
    return h
示例19
def __init__(self, num_categories, embedding_dim, num_collisions,
             operation='mult', max_norm=None, norm_type=2.,
             scale_grad_by_freq=False, mode='mean', sparse=False,
             _weight=None):
    """Quotient-remainder embedding bag.

    Category indices are split by ``num_collisions`` into a quotient
    (row of ``weight_q``) and a remainder (row of ``weight_r``).

    Args:
        num_categories: total number of categories.
        embedding_dim: an int, or a sequence of one or two ints giving the
            widths of the quotient and remainder tables.
        num_collisions: divisor used to split indices; also the number of
            rows in the remainder table.
        operation: 'concat', 'mult' or 'add'. 'mult' and 'add' require equal
            widths.
        max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored as
            attributes (presumably forwarded to the embedding-bag lookups).
        _weight: optional pair ``(weight_q, weight_r)`` of pre-built tables
            whose shapes must match; if None, new tables are allocated and
            ``reset_parameters()`` is called.
    """
    super(QREmbeddingBag, self).__init__()
    assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'
    self.num_categories = num_categories
    if isinstance(embedding_dim, int):
        self.embedding_dim = [embedding_dim, embedding_dim]
    elif len(embedding_dim) == 1:
        # Unwrap a one-element sequence; the old code produced [[d], [d]].
        self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
    else:
        self.embedding_dim = embedding_dim
    self.num_collisions = num_collisions
    self.operation = operation
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    if self.operation == 'add' or self.operation == 'mult':
        # Element-wise combination needs equal widths.
        assert self.embedding_dim[0] == self.embedding_dim[1], \
            'Embedding dimensions do not match!'
    # Quotient table: ceil(num_categories / num_collisions) rows;
    # remainder table: num_collisions rows.
    self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
                           num_collisions]
    if _weight is None:
        # torch.Tensor(rows, cols) is uninitialized; reset_parameters()
        # presumably fills it.
        self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
        self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
        self.reset_parameters()
    else:
        assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
            'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
        assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
            'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
        self.weight_q = Parameter(_weight[0])
        self.weight_r = Parameter(_weight[1])
    self.mode = mode
    self.sparse = sparse
示例20
def __init__(self, num_categories, embedding_dim, num_collisions,
             operation='mult', max_norm=None, norm_type=2.,
             scale_grad_by_freq=False, mode='mean', sparse=False,
             _weight=None):
    """Quotient-remainder embedding bag.

    Category indices are split by ``num_collisions`` into a quotient
    (row of ``weight_q``) and a remainder (row of ``weight_r``).

    Args:
        num_categories: total number of categories.
        embedding_dim: an int, or a sequence of one or two ints giving the
            widths of the quotient and remainder tables.
        num_collisions: divisor used to split indices; also the number of
            rows in the remainder table.
        operation: 'concat', 'mult' or 'add'. 'mult' and 'add' require equal
            widths.
        max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored as
            attributes (presumably forwarded to the embedding-bag lookups).
        _weight: optional pair ``(weight_q, weight_r)`` of pre-built tables
            whose shapes must match; if None, new tables are allocated and
            ``reset_parameters()`` is called.
    """
    super(QREmbeddingBag, self).__init__()
    assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'
    self.num_categories = num_categories
    if isinstance(embedding_dim, int):
        self.embedding_dim = [embedding_dim, embedding_dim]
    elif len(embedding_dim) == 1:
        # Unwrap a one-element sequence; the old code produced [[d], [d]].
        self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
    else:
        self.embedding_dim = embedding_dim
    self.num_collisions = num_collisions
    self.operation = operation
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    if self.operation == 'add' or self.operation == 'mult':
        # Element-wise combination needs equal widths.
        assert self.embedding_dim[0] == self.embedding_dim[1], \
            'Embedding dimensions do not match!'
    # Quotient table: ceil(num_categories / num_collisions) rows;
    # remainder table: num_collisions rows.
    self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
                           num_collisions]
    if _weight is None:
        # torch.Tensor(rows, cols) is uninitialized; reset_parameters()
        # presumably fills it.
        self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
        self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
        self.reset_parameters()
    else:
        assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
            'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
        assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
            'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
        self.weight_q = Parameter(_weight[0])
        self.weight_r = Parameter(_weight[1])
    self.mode = mode
    self.sparse = sparse
示例21
def log_uniform_sample(N, size):
    """Draw int64 samples in [0, N) whose distribution is uniform in log space.

    ``value + 1 = floor(N ** x)`` with ``x ~ U[0, 1)``, so smaller values are
    more likely.

    Args:
        N: number of classes (positive int).
        size: an int or a shape sequence (tuple/list/torch.Size).

    Returns:
        An int64 tensor of shape ``size``.
    """
    log_N = math.log(N)
    # torch.empty always reads `size` as a shape. The legacy torch.Tensor(size)
    # read a plain tuple such as (2, 3) as *data* and returned the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    # Guard against float rounding pushing a value out of range.
    return torch.remainder(value, N)
示例22
def log_uniform_sample(N, size):
    """Draw int64 samples in [0, N) whose distribution is uniform in log space.

    ``value + 1 = floor(N ** x)`` with ``x ~ U[0, 1)``, so smaller values are
    more likely.

    Args:
        N: number of classes (positive int).
        size: an int or a shape sequence (tuple/list/torch.Size).

    Returns:
        An int64 tensor of shape ``size``.
    """
    log_N = math.log(N)
    # torch.empty always reads `size` as a shape. The legacy torch.Tensor(size)
    # read a plain tuple such as (2, 3) as *data* and returned the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    # Guard against float rounding pushing a value out of range.
    return torch.remainder(value, N)
示例23
def fmod(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. C Library function fmod), not commutative.
    Takes the two operands (scalar or tensor, both may contain floating point numbers) whose elements are to be
    divided (operand 1 by operand 2) as arguments.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand whose values are divided (may be floats)
    t2: tensor or scalar
        The second operand by whose values is divided (may be floats)

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division (i.e. floating point values) of t1 by t2.
        It has the same sign as the dividend t1.

    Examples
    --------
    >>> import heat as ht
    >>> ht.fmod(2.0, 2.0)
    tensor([0.])
    >>> T1 = ht.float32([[1, 2], [3, 4]])
    >>> T2 = ht.float32([[2, 2], [2, 2]])
    >>> ht.fmod(T1, T2)
    tensor([[1., 0.],
            [1., 0.]])
    >>> s = 2.0
    >>> ht.fmod(s, T1)
    tensor([[0., 0.],
            [2., 2.]])
    """
    return operations.__binary_op(torch.fmod, t1, t2)
示例24
def mod(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. t1 % t2), not commutative.
    Takes the two operands (scalar or tensor) whose elements are to be divided (operand 1 by operand 2) as arguments.
    Currently t1 and t2 are just passed to remainder.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand whose values are divided
    t2: tensor or scalar
        The second operand by whose values is divided

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division of t1 by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.mod(2, 2)
    tensor([0])
    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.mod(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)
    >>> s = 2
    >>> ht.mod(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    return remainder(t1, t2)
示例25
def remainder(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. t1 % t2), not commutative.
    Takes the two operands (scalar or tensor) whose elements are to be divided (operand 1 by operand 2) as arguments.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand whose values are divided
    t2: tensor or scalar
        The second operand by whose values is divided

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division of t1 by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.remainder(2, 2)
    tensor([0])
    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.remainder(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)
    >>> s = 2
    >>> ht.remainder(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    return operations.__binary_op(torch.remainder, t1, t2)
示例26
def log_uniform_sample(N, size):
    """Draw int64 samples in [0, N) whose distribution is uniform in log space.

    ``value + 1 = floor(N ** x)`` with ``x ~ U[0, 1)``, so smaller values are
    more likely.

    Args:
        N: number of classes (positive int).
        size: an int or a shape sequence (tuple/list/torch.Size).

    Returns:
        An int64 tensor of shape ``size``.
    """
    log_N = math.log(N)
    # torch.empty always reads `size` as a shape. The legacy torch.Tensor(size)
    # read a plain tuple such as (2, 3) as *data* and returned the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    # Guard against float rounding pushing a value out of range.
    return torch.remainder(value, N)
示例27
def log_uniform_sample(N, size):
    """Draw int64 samples in [0, N) whose distribution is uniform in log space.

    ``value + 1 = floor(N ** x)`` with ``x ~ U[0, 1)``, so smaller values are
    more likely.

    Args:
        N: number of classes (positive int).
        size: an int or a shape sequence (tuple/list/torch.Size).

    Returns:
        An int64 tensor of shape ``size``.
    """
    log_N = math.log(N)
    # torch.empty always reads `size` as a shape. The legacy torch.Tensor(size)
    # read a plain tuple such as (2, 3) as *data* and returned the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    # Guard against float rounding pushing a value out of range.
    return torch.remainder(value, N)
示例28
def log_uniform_sample(N, size):
    """Draw int64 samples in [0, N) whose distribution is uniform in log space.

    ``value + 1 = floor(N ** x)`` with ``x ~ U[0, 1)``, so smaller values are
    more likely.

    Args:
        N: number of classes (positive int).
        size: an int or a shape sequence (tuple/list/torch.Size).

    Returns:
        An int64 tensor of shape ``size``.
    """
    log_N = math.log(N)
    # torch.empty always reads `size` as a shape. The legacy torch.Tensor(size)
    # read a plain tuple such as (2, 3) as *data* and returned the wrong shape.
    x = torch.empty(size).uniform_(0, 1)
    value = torch.exp(x * log_N).long() - 1
    # Guard against float rounding pushing a value out of range.
    return torch.remainder(value, N)
示例29
def calc_coef_and_indices(tm, kofflist, Jval, table, centers, L, dims, conjcoef=False):
    """Calculate interpolation coefficients and on-grid indices.

    Args:
        tm (tensor): normalized frequency locations; ``tm.shape[0]`` is the
            number of spatial dimensions and ``tm.shape[1]`` the number of points.
        kofflist (tensor): A tensor with offset locations to first elements in
            list of nearest neighbors.
        Jval (tensor): A tuple-like tensor for how much to increment offsets.
        table (list): A list of tensors tabulating a Kaiser-Bessel
            interpolation kernel.
        centers (tensor): A tensor with the center locations of the table for
            each dimension.
        L (tensor): A tensor with the table size in each dimension.
        dims (tensor): A tensor with image dimensions.
        conjcoef (boolean, default=False): A boolean for whether to compute
            normal or complex conjugate interpolation coefficients
            (conjugate needed for adjoint).

    Returns:
        tuple: A tuple with interpolation coefficients and indices. The
        coefficients are a stack of two rows, presumably (real, imaginary);
        the indices are flattened row-major grid positions.
    """
    dtype, device = tm.dtype, tm.device
    int_type = torch.long
    num_dims, num_points = tm.shape[0], tm.shape[1]

    # Choose the complex multiply once instead of branching per dimension.
    multiply = conj_complex_mult if conjcoef else complex_mult

    # Grid neighbor locations and their table offsets.
    grid_pos = (kofflist + Jval.unsqueeze(1)).to(dtype)
    table_offset = torch.round((tm - grid_pos) * L.unsqueeze(1)).to(dtype=int_type)
    grid_pos = grid_pos.to(int_type)

    flat_index = torch.zeros((num_points,), dtype=int_type, device=device)
    coef = torch.stack((
        torch.ones(num_points, dtype=dtype, device=device),
        torch.zeros(num_points, dtype=dtype, device=device),
    ))
    for d in range(num_dims):
        coef = multiply(coef, table[d][:, table_offset[d, :] + centers[d]], dim=0)
        # Wrap around the grid, then accumulate the row-major flat index.
        wrapped = torch.remainder(grid_pos[d, :], dims[d]).view(-1)
        flat_index = flat_index + wrapped * torch.prod(dims[d + 1:])
    return coef, flat_index
示例30
def _val_epoch(self, epoch):
    """Evaluate on ``self.val_loader``, optionally also under ``self.attack``.

    Standard loss/accuracy is always measured. When ``self.attack`` is set, a
    random target class different from the true label is drawn for each
    sample. Two adversarial inputs are produced: one with
    ``scale_eps=self.scale_eps`` (``val_adv_*``) and one with
    ``scale_eps=False`` (``val_max_adv_*``). Rank 0 (``hvd.rank()``) logs the
    averaged metrics. The optimizer is then synchronized and its gradients
    zeroed, presumably to discard gradients the attack left behind — confirm.

    Args:
        epoch: epoch number passed to ``self.logger.log``.
    """
    self.model.eval()
    val_std_loss = Metric('val_std_loss')
    val_std_acc = Metric('val_std_acc')
    val_adv_acc = Metric('val_adv_acc')
    val_adv_loss = Metric('val_adv_loss')
    val_max_adv_acc = Metric('val_max_adv_acc')
    val_max_adv_loss = Metric('val_max_adv_loss')
    num_classes = len(self.val_dataset.classes)
    for data, target in self.val_loader:
        if self.cuda:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
        with torch.no_grad():
            output = self.model(data)
            val_std_loss.update(F.cross_entropy(output, target))
            val_std_acc.update(accuracy(output, target))
        if self.attack:
            # Offset in [1, num_classes - 1] so the target never equals the
            # true label. Allocated on target's device rather than a hard-coded
            # 'cuda', which broke CPU runs.
            rand_target = torch.randint(
                0, num_classes - 1, target.size(),
                dtype=target.dtype, device=target.device)
            rand_target = torch.remainder(target + rand_target + 1, num_classes)
            data_adv = self.attack(self.model, data, rand_target,
                                   avoid_target=False, scale_eps=self.scale_eps)
            data_max_adv = self.attack(self.model, data, rand_target, avoid_target=False, scale_eps=False)
            with torch.no_grad():
                output_adv = self.model(data_adv)
                val_adv_loss.update(F.cross_entropy(output_adv, target))
                val_adv_acc.update(accuracy(output_adv, target))
                output_max_adv = self.model(data_max_adv)
                val_max_adv_loss.update(F.cross_entropy(output_max_adv, target))
                val_max_adv_acc.update(accuracy(output_max_adv, target))
        # NOTE(review): re-asserts eval mode, presumably in case the attack
        # changed it — confirm whether train() was intended somewhere instead.
        self.model.eval()
    if hvd.rank() == 0:
        # Each metric now has its own key; previously 'val_adv_loss' appeared
        # twice, so the max-adv loss overwrote the adv loss and was logged
        # under the wrong name.
        log_dict = {'val_std_loss': val_std_loss.avg.item(),
                    'val_std_acc': val_std_acc.avg.item(),
                    'val_adv_loss': val_adv_loss.avg.item(),
                    'val_adv_acc': val_adv_acc.avg.item(),
                    'val_max_adv_loss': val_max_adv_loss.avg.item(),
                    'val_max_adv_acc': val_max_adv_acc.avg.item()}
        self.logger.log(log_dict, epoch)
        if self.verbose:
            print(log_dict)
    self.optimizer.synchronize()
    self.optimizer.zero_grad()