Python source code examples: torch.jit
Example 1
def __init__(self, models, tgt_dict, max_iter=1, quantize=True, check_trace=True):
super().__init__()
src_tokens = torch.tensor([[4, 2]])
src_lengths = torch.tensor([2])
self.models = models
generator = IterativeRefinementGenerator(
self.models, tgt_dict, max_iter=max_iter
)
if quantize:
generator = torch.quantization.quantize_dynamic(
generator, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)
enc_inputs = (src_tokens, src_lengths)
self.generator = torch.jit.trace(
generator, enc_inputs, _force_outplace=True, check_trace=check_trace
)
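To illustrate the same quantize-then-trace pattern on a self-contained toy module (the model and inputs below are made up for demonstration, not taken from the original project):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
# Dynamically quantize all Linear layers to int8 weights.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# Trace with a representative input so the result can be saved as a ScriptModule.
traced = torch.jit.trace(quantized, torch.randn(1, 8))
print(traced(torch.randn(2, 8)).shape)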
Example 2
def conv_flop_jit(
inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
"""
This method counts the flops for convolution using torch script.
Args:
inputs (list(torch._C.Value)): The input shape in the form of a list of
jit object before convolution.
outputs (list(torch._C.Value)): The output shape in the form of a list
of jit object after convolution.
Returns:
Counter: A Counter dictionary that records the number of flops for each
operation.
"""
# Inputs of Convolution should be a list of length 12. They represent:
# 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
# 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
# 10) deterministic_cudnn and 11) user_enabled_cudnn.
assert len(inputs) == 12, len(inputs)
x, w = inputs[:2]
x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
return conv_flop_count(x_shape, w_shape, out_shape)
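`conv_flop_count` is defined elsewhere in the same module. A minimal sketch of a shape-based count, under the simplifying assumption of an ordinary (non-grouped, non-transposed) convolution, could look like this:

from collections import Counter
import typing
from numpy import prod

def conv_flop_count(
    x_shape: typing.List[int],
    w_shape: typing.List[int],
    out_shape: typing.List[int],
) -> typing.Counter[str]:
    # One multiply-accumulate per output spatial element per filter weight:
    # batch * out_spatial * (C_out * C_in * k_h * k_w), with w_shape = [C_out, C_in, k_h, k_w].
    batch_size = x_shape[0]
    out_spatial = prod(out_shape[2:])
    flop = batch_size * out_spatial * prod(w_shape)
    return Counter({"conv": flop})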
Example 3
def batchnorm_flop_jit(
inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
"""
This method counts the flops for batch norm.
Args:
inputs (list(torch._C.Value)): The input shape in the form of a list of
jit object before batch norm.
outputs (list(torch._C.Value)): The output shape in the form of a list
of jit object after batch norm.
Returns:
Counter: A Counter dictionary that records the number of flops for each
operation.
"""
# Inputs[0] contains the shape of the input.
input_shape = get_shape(inputs[0])
assert 2 <= len(input_shape) <= 5, input_shape
flop = prod(input_shape) * 4
flop_counter = Counter({"batchnorm": flop})
return flop_counter
Example 4
def forward(self, enc, dec):
""" Forward pass
Arguments:
enc: Tensor from the encoder pathway
dec: Tensor from the decoder pathway (to be upconv'd)
"""
updec = self.upconv(dec)
enc, updec = autocrop(enc, updec)
genc, att = self.attention(enc, dec)
if not torch.jit.is_scripting():
self.att = att
updec = self.norm0(updec)
updec = self.act0(updec)
if self.merge_mode == 'concat':
mrg = torch.cat((updec, genc), 1)
else:
mrg = updec + genc
y = self.conv1(mrg)
y = self.norm1(y)
y = self.act1(y)
y = self.conv2(y)
y = self.norm2(y)
y = self.act2(y)
return y
Example 5
def __init__(self,
module,
example_inputs):
super().__init__()
self.module = module
is_class = isinstance(module, torch.nn.Module)
trace = torch.jit.trace(module, example_inputs, True)
if not isinstance(example_inputs, (list, tuple)):
example_inputs = [example_inputs]
graph_py = parse(
trace.graph, len(example_inputs), omit_useless_nodes=False, is_class=is_class)
self.graph = graph_py
self.trace = trace
self.example_inputs = example_inputs
msg = "input mismatch. this may due to some input isn't used in graph"
assert len(example_inputs) + int(is_class) == len(graph_py.get_input_nodes_dict()), msg
Example 6
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
for rnn_layer in self.layers:
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
i += 1
return output, output_states
# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
Example 7
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
for rnn_layer in self.layers:
state = states[i]
output, out_state = rnn_layer(output, state)
# Apply the dropout layer except the last layer
if i < self.num_layers - 1:
output = self.dropout_layer(output)
output_states += [out_state]
i += 1
return output, output_states
Example 8
def import_model(path=None):
"""
Imports a model (as ScriptModule) from file.
Parameters
----------
path : str
Path to where the model is saved. Defaults to the return value of the `get_model_path`
function above.
Returns
-------
torch.jit.ScriptModule
The loaded ScriptModule.
"""
path = get_model_path() if path is None else path
return torch.jit.load(path)
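The returned ScriptModule behaves like a regular module; a typical use (the file name below is just a placeholder):

import torch

model = torch.jit.load("model.pt")  # placeholder path
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))  # input shape depends on the saved model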
Example 9
def forward(self, src_tokens, src_lengths):
# (seq_length, batch_size) for compatibility with Caffe2
src_tokens_seq_first = src_tokens.t()
futures = []
for model in self.models:
# evaluation mode
model.eval()
futures.append(
torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
)
return self.get_outputs(src_tokens, futures)
Example 10
def save_to_pytorch(self, output_path):
def pack(s):
if hasattr(s, "_pack"):
s._pack()
def unpack(s):
if hasattr(s, "_unpack"):
s._unpack()
self.apply(pack)
torch.jit.save(self, output_path)
self.apply(unpack)
Example 11
def forward(
self,
src_tokens: torch.Tensor,
src_lengths: torch.Tensor,
prev_token: torch.Tensor,
prev_scores: torch.Tensor,
attn_weights: torch.Tensor,
prev_hypos_indices: torch.Tensor,
num_steps: int,
) -> List[Tuple[Tensor, float, List[float], Tensor, Tensor]]:
beam_search_out = self.beam_search(
src_tokens,
src_lengths,
prev_token,
prev_scores,
attn_weights,
prev_hypos_indices,
num_steps,
)
all_tokens, all_scores, all_weights, all_prev_indices = beam_search_out
outputs = torch.jit.annotate(
List[Tuple[Tensor, float, List[float], Tensor, Tensor]], []
)
outputs = self.beam_decode(
all_tokens, all_scores, all_weights, all_prev_indices, num_steps
)
return outputs
Example 12
def save_to_pytorch(self, output_path):
def pack(s):
if hasattr(s, "_pack"):
s._pack()
def unpack(s):
if hasattr(s, "_unpack"):
s._unpack()
self.apply(pack)
torch.jit.save(self, output_path)
self.apply(unpack)
Example 13
def _get_all_end_states(
self,
beam_tokens: Tensor,
beam_scores: Tensor,
beam_prev_indices: Tensor,
num_steps: int,
) -> Tensor:
min_score = float("inf")
min_index = -1
end_states = torch.jit.annotate(List[Tensor], [])
position = 1
while bool(position <= num_steps + 1):
for hyp_index in range(self.beam_size):
if bool(beam_tokens[position][hyp_index] == self.eos_token_id) or bool(
position == num_steps + 1
):
hypo_score = float(beam_scores[position][hyp_index])
if bool(self.length_penalty != 0):
hypo_score = hypo_score / float(position) ** float(
self.length_penalty
)
end_states, min_score, min_index = self._add_to_end_states(
end_states,
min_score,
torch.tensor([hypo_score, float(position), float(hyp_index)]),
min_index,
)
position = position + 1
end_states = torch.stack(end_states)
_, sorted_end_state_indices = end_states[:, 0].sort(dim=0, descending=True)
end_states = end_states[sorted_end_state_indices, :]
return end_states
Example 14
def save_to_pytorch(self, output_path):
def pack(s):
if hasattr(s, "_pack"):
s._pack()
def unpack(s):
if hasattr(s, "_unpack"):
s._unpack()
self.apply(pack)
torch.jit.save(self, output_path)
self.apply(unpack)
Example 15
def generic_activation_jit(
op_name: str,
) -> typing.Callable[[typing.List[object], typing.List[object]], typing.Counter[str]]:
"""
This method returns a handle that counts the number of activations from the
output shape for the specified operation.
Args:
op_name (str): The name of the operation.
Returns:
typing.Callable: An activation handle for the given operation.
"""
def _generic_activation_jit(outputs: typing.List[object]) -> int:
"""
This is a generic jit handle that counts the number of activations for any
operation given the output shape.
Args:
outputs (list(torch._C.Value)): The output shape in the form of a list
of jit object.
Returns:
int: Total number of activations for each operation.
"""
out_shape = get_shape(outputs[0])
ac_count = prod(out_shape)
return ac_count
return lambda inputs, outputs: Counter({op_name: _generic_activation_jit(outputs)})
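The returned handle takes the same (inputs, outputs) arguments as the flop handles above, so it can be stored in a plain mapping from jit op names to handles; the op names below are illustrative, not an authoritative list:

activation_handles = {
    "aten::relu": generic_activation_jit("relu"),
    "aten::_convolution": generic_activation_jit("conv"),
}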
Example 16
def get_shape(val: object) -> typing.List[int]:
"""
Get the shapes from a jit value object.
Args:
val (torch._C.Value): jit value object.
Returns:
list(int): A list of ints giving the tensor's shape.
"""
if val.isCompleteTensor(): # pyre-ignore
return val.type().sizes() # pyre-ignore
else:
raise ValueError()
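A sketch of how get_shape might be applied to a traced graph, assuming the traced values carry complete tensor types (the guard below simply skips values that do not):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
for node in traced.graph.nodes():
    outs = list(node.outputs())
    if outs and outs[0].isCompleteTensor():
        print(node.kind(), get_shape(outs[0]))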
Example 17
def addmm_flop_jit(
inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
"""
This method counts the flops for fully connected layers with torch script.
Args:
inputs (list(torch._C.Value)): The input shape in the form of a list of
jit object.
outputs (list(torch._C.Value)): The output shape in the form of a list
of jit object.
Returns:
Counter: A Counter dictionary that records the number of flops for each
operation.
"""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [get_shape(v) for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [input feature dimension, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flop = batch_size * input_dim * output_dim
flop_counter = Counter({"addmm": flop})
return flop_counter
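As a worked example, an nn.Linear(512, 1000) applied to a batch of 32 lowers to an addmm whose two matrix shapes are [32, 512] and [512, 1000], so the handle above reports 32 * 512 * 1000 = 16,384,000 flops (counting one flop per multiply-accumulate).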
Example 18
def matmul_flop_jit(
inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
"""
This method counts the flops for matmul.
Args:
inputs (list(torch._C.Value)): The input shape in the form of a list of
jit object before matmul.
outputs (list(torch._C.Value)): The output shape in the form of a list
of jit object after matmul.
Returns:
Counter: A Counter dictionary that records the number of flops for each
operation.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [get_shape(v) for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert len(input_shapes[1]) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][0], input_shapes
batch_dim = input_shapes[0][0]
m1_dim, m2_dim = input_shapes[1]
flop = m1_dim * m2_dim * batch_dim
flop_counter = Counter({"matmul": flop})
return flop_counter
Example 19
def __call__(self, *args, **kwargs):
method_model = _ForwardOverrideModel(self.model, self.method_name)
example_inputs = {
self.method_name: kwargs if len(kwargs) > 0 else args
}
# noinspection PyTypeChecker
self.tracing_result = torch.jit.trace(
method_model, example_inputs=example_inputs
)
Example 20
def test_cutmix():
input = torch.empty(4, 3, 32, 32)
target = torch.tensor([1, 2, 3, 4], dtype=torch.long)
cutmix(input, target, 0.1)
jit_cutmix = torch.jit.script(cutmix)
jit_cutmix(input, target, 0.1)
Example 21
def test_mixup():
input = torch.empty(4, 3, 32, 32)
target = torch.tensor([1, 2, 3, 4], dtype=torch.long)
mixup(input, target, 0.1)
jit_mixup = torch.jit.script(mixup)
jit_mixup(input, target, 0.1)
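Both tests follow the same pattern: run the augmentation eagerly, script it with torch.jit.script, and call the scripted version on the same tensors. The pattern works for any standalone, type-annotated function; for instance (a made-up toy function, not from the original test suite):

import torch

def scale_and_shift(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha + 1.0

scripted = torch.jit.script(scale_and_shift)
assert torch.allclose(scripted(torch.ones(3), 2.0), scale_and_shift(torch.ones(3), 2.0))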
Example 22
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = reverse(input.unbind(0))
outputs = jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
return torch.stack(reverse(outputs)), state
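`reverse` is a helper defined elsewhere in the same script; a minimal TorchScript-compatible definition consistent with how it is used above (an assumption, not the verbatim original) would be:

from typing import List
from torch import Tensor

def reverse(lst: List[Tensor]) -> List[Tensor]:
    # Negative-step list slicing is supported by TorchScript, so this compiles
    # when called from a scripted forward method.
    return lst[::-1]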
Example 23
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: [forward LSTMState, backward LSTMState]
outputs = jit.annotate(List[Tensor], [])
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
for direction in self.directions:
state = states[i]
out, out_state = direction(input, state)
outputs += [out]
output_states += [out_state]
i += 1
return torch.cat(outputs, -1), output_states
Example 24
def forward(self, input, states):
# type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
# List[List[LSTMState]]: The outer list is for layers,
# inner list is for directions.
output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
for rnn_layer in self.layers:
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
i += 1
return output, output_states
Example 25
def export_model(model, path=None, input_shape=(1, 3, 64, 64)):
"""
Exports the model. If the model is a `ScriptModule`, it is saved as is. If not,
it is traced (with the given input_shape) and the resulting ScriptModule is saved
(this requires the `input_shape`, which defaults to the competition default).
Parameters
----------
model : torch.nn.Module or torch.jit.ScriptModule
Pytorch Module or a ScriptModule.
path : str
Path to the file where the model is saved. Defaults to the value set by the
`get_model_path` function above.
input_shape : tuple or list
Shape of the input to trace the module with. This is only required if model is not a
torch.jit.ScriptModule.
Returns
-------
str
Path to where the model is saved.
"""
path = get_model_path() if path is None else path
model = deepcopy(model).cpu().eval()
if not isinstance(model, torch.jit.ScriptModule):
assert input_shape is not None, "`input_shape` must be provided since model is not a " \
"`ScriptModule`."
traced_model = trace(model, torch.zeros(*input_shape))
else:
traced_model = model
torch.jit.save(traced_model, path)
return path
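A round trip through export_model and torch.jit.load (see also import_model in Example 8) might look like this; the network below is a stand-in, and the 64x64 input shape is the default mentioned in the docstring:

import torch

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, 10),
)
saved_path = export_model(net, path="exported.pt", input_shape=(1, 3, 64, 64))
restored = torch.jit.load(saved_path)
print(restored(torch.zeros(1, 3, 64, 64)).shape)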
Example 26
def make_representor(model, cuda=None):
"""
Encloses the pytorch ScriptModule in a callable that can be used by `disentanglement_lib`.
Parameters
----------
model : torch.nn.Module or torch.jit.ScriptModule
The Pytorch model.
cuda : bool
Whether to use CUDA for inference. Defaults to the return value of the `use_cuda`
function defined above.
Returns
-------
callable
A callable function (`representation_function` in dlib code)
"""
# Deepcopy doesn't work on ScriptModule objects yet:
# https://github.com/pytorch/pytorch/issues/18106
# model = deepcopy(model)
cuda = use_cuda() if cuda is None else cuda
model = model.cuda() if cuda else model.cpu()
# Define the representation function
def _represent(x):
assert isinstance(x, np.ndarray), \
"Input to the representation function must be a ndarray."
assert x.ndim == 4, \
"Input to the representation function must be a four dimensional NHWC tensor."
# Convert from NHWC to NCHW
x = np.moveaxis(x, 3, 1)
# Convert to torch tensor and evaluate
x = torch.from_numpy(x).float().to('cuda' if cuda else 'cpu')
with torch.no_grad():
y = model(x)
y = y.cpu().numpy()
assert y.ndim == 2, \
"The returned output from the representor must be two dimensional (NC)."
return y
return _represent
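Given a trained model, usage might look like this (the batch size and image shape are illustrative; the model is assumed to map NCHW images to a 2-D code):

import numpy as np

represent = make_representor(model, cuda=False)
batch = np.random.rand(8, 64, 64, 3).astype(np.float32)  # NHWC, as disentanglement_lib provides it
codes = represent(batch)
print(codes.shape)  # (num_samples, latent_dim)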
Example 27
def forward(self, inputs, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
outputs = jit.annotate(List[Tensor], [])
seq_len = inputs.size(0)
for i in range(seq_len):
out, state = self.cell(inputs[seq_len - i - 1], state)
# workaround for the lack of list rev support
outputs = [out] + outputs
return torch.stack(outputs), state
Example 28
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: [forward LSTMState, backward LSTMState]
outputs = jit.annotate(List[Tensor], [])
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
for (i, direction) in enumerate(self.directions):
state = states[i]
out, out_state = direction(input, state)
outputs += [out]
output_states += [out_state]
# tensor array concat assumes axis == 0 for now
# return torch.cat(outputs, -1), output_states
return torch.cat(outputs, 0), output_states
Example 29
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
output = input
for (i, rnn_layer) in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states
Example 30
def forward(self, input, states):
# type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
# List[List[LSTMState]]: The outer list is for layers,
# inner list is for directions.
output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
output = input
for (i, rnn_layer) in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states