Python源码示例:ignite.engine.Events.ITERATION_STARTED
示例1
def test_terminate_epoch_stops_mid_epoch():
    """Calling terminate_epoch() mid-epoch skips the rest of that epoch but not later epochs."""
    max_epochs = 3
    iters_per_epoch = 10
    stop_at = iters_per_epoch + 4  # an iteration inside the second epoch

    engine = Engine(MagicMock(return_value=1))

    def on_iteration_started(eng):
        if eng.state.iteration == stop_at:
            eng.terminate_epoch()

    engine.add_event_handler(Events.ITERATION_STARTED, on_iteration_started)
    state = engine.run(data=[None] * iters_per_epoch, max_epochs=max_epochs)

    # The iteration during which terminate_epoch() is called still completes, so the
    # interrupted epoch contributes `stop_at % iters_per_epoch` iterations.
    expected = iters_per_epoch * (max_epochs - 1) + stop_at % iters_per_epoch
    assert state.iteration == expected
示例2
def test_callable_events_with_wrong_inputs():
    """Invalid argument combinations passed to a callable event must be rejected."""
    bad_calls = [
        (ValueError, r"Only one of the input arguments should be specified", {}),
        (ValueError, r"Only one of the input arguments should be specified", {"event_filter": "123", "every": 12}),
        (TypeError, r"Argument event_filter should be a callable", {"event_filter": "123"}),
        (ValueError, r"Argument every should be integer and greater than zero", {"every": -1}),
        (ValueError, r"but will be called with", {"event_filter": lambda x: x}),
    ]
    for expected_exc, pattern, kwargs in bad_calls:
        with pytest.raises(expected_exc, match=pattern):
            Events.ITERATION_STARTED(**kwargs)
示例3
def test_list_of_events():
    """A handler attached to an OR-combination of filtered events fires for each match."""

    def check(events, expected_iterations):
        engine = Engine(lambda e, b: b)
        seen = []
        call_count = 0

        @engine.on(events)
        def record(e):
            nonlocal call_count
            seen.append(e.state.iteration)
            call_count += 1

        engine.run(range(3), max_epochs=5)
        assert seen == expected_iterations
        assert call_count == len(expected_iterations)

    check(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(once=1), [1, 1])
    check(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(once=10), [1, 10])
    check(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(every=3), [1, 3, 6, 9, 12, 15])
示例4
def test_optimizer_params():
    """OptimizerParamsHandler writes each param group's lr to the Tensorboard writer."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_name):
        logger = MagicMock(spec=TensorboardLogger)
        logger.writer = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.writer.add_scalar.assert_called_once_with(expected_name, 0.01, 123)

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr/group_0")
示例5
def test_weights_scalar_handler_wrong_setup():
    """WeightsScalarHandler rejects bad constructor arguments and non-Tensorboard loggers."""
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    model = MagicMock(spec=torch.nn.Module)
    bad_reductions = [
        (TypeError, "Argument reduction should be callable", 123),
        (ValueError, "Output of the reduction function should be a scalar", lambda x: x),
    ]
    for exc_type, pattern, reduction in bad_reductions:
        with pytest.raises(exc_type, match=pattern):
            WeightsScalarHandler(model, reduction=reduction)

    handler = WeightsScalarHandler(model)
    # A plain MagicMock is not a TensorboardLogger, so invoking the handler must fail.
    plain_logger = MagicMock()
    plain_engine = MagicMock()
    with pytest.raises(RuntimeError, match="Handler 'WeightsScalarHandler' works only with TensorboardLogger"):
        handler(plain_engine, plain_logger, Events.ITERATION_STARTED)
示例6
def test_weights_scalar_handler_wrong_setup():
    """WeightsScalarHandler rejects bad constructor arguments and non-Visdom loggers."""
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    model = MagicMock(spec=torch.nn.Module)
    bad_reductions = [
        (TypeError, "Argument reduction should be callable", 123),
        (ValueError, "Output of the reduction function should be a scalar", lambda x: x),
    ]
    for exc_type, pattern, reduction in bad_reductions:
        with pytest.raises(exc_type, match=pattern):
            WeightsScalarHandler(model, reduction=reduction)

    handler = WeightsScalarHandler(model)
    # A plain MagicMock is not a VisdomLogger, so invoking the handler must fail.
    plain_logger = MagicMock()
    plain_engine = MagicMock()
    with pytest.raises(RuntimeError, match="Handler 'WeightsScalarHandler' works only with VisdomLogger"):
        handler(plain_engine, plain_logger, Events.ITERATION_STARTED)
示例7
def test_optimizer_params():
    """OptimizerParamsHandler reports each param group's lr via trains_logger.report_scalar."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_title):
        logger = MagicMock(spec=TrainsLogger)
        logger.trains_logger = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.trains_logger.report_scalar.assert_called_once_with(
            iteration=123, series="0", title=expected_title, value=0.01
        )

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr")
示例8
def test_output_handler_output_transform(dirname):
    """OutputHandler reports the transformed engine output via trains_logger.report_scalar.

    The ``dirname`` fixture is not used by this test; it is kept so the signature is unchanged.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_series, expected_title):
        logger = MagicMock(spec=TrainsLogger)
        logger.trains_logger = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.trains_logger.report_scalar.assert_called_once_with(
            iteration=123, series=expected_series, title=expected_title, value=12345
        )

    # A bare value is reported under the series "output"; dict keys become series names.
    check(OutputHandler("tag", output_transform=lambda x: x), "output", "tag")
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}), "loss", "another_tag")
示例9
def test_output_handler_output_transform():
    """OutputHandler logs the transformed engine output through PolyaxonLogger.log_metrics."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=PolyaxonLogger)
        logger.log_metrics = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with(step=123, **{expected_key: 12345})

    check(OutputHandler("tag", output_transform=lambda x: x), "tag/output")
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}), "another_tag/loss")
示例10
def test_optimizer_params():
    """OptimizerParamsHandler logs each param group's lr through PolyaxonLogger.log_metrics."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=PolyaxonLogger)
        logger.log_metrics = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with(**{expected_key: 0.01, "step": 123})

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr/group_0")
示例11
def test_pbar_batch_indeces(capsys):
    """The progress bar output shows every batch index from 1 to 4."""
    engine = Engine(lambda e, b: time.sleep(0.1))

    @engine.on(Events.ITERATION_STARTED)
    def print_iter(_):
        print("iteration: ", engine.state.iteration)

    ProgressBar(persist=True).attach(engine)
    engine.run(list(range(4)), max_epochs=1)

    captured = capsys.readouterr()
    # Split the stderr stream on carriage returns and drop blank fragments.
    frames = [chunk.strip() for chunk in captured.err.split("\r") if chunk.strip()]
    # Only the last character before "/" is parsed, which suffices for single-digit indices.
    indices = {int(frame.split("/")[0][-1]) for frame in frames}
    assert sorted(indices) == list(range(1, 5))
示例12
def test_pbar_wrong_events_order():
    """ProgressBar.attach rejects mis-ordered event pairs and filtered closing events."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    misordered = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for event, closing_event in misordered:
        with pytest.raises(ValueError, match="should be called before closing event"):
            pbar.attach(engine, event_name=event, closing_event_name=closing_event)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
示例13
def test_pbar_on_callable_events(capsys):
    """A progress bar driven by a filtered event renders the expected final frame."""
    data = list(range(100))
    engine = Engine(update_fn)
    ProgressBar().attach(
        engine, event_name=Events.ITERATION_STARTED(every=10), closing_event_name=Events.EPOCH_COMPLETED
    )
    engine.run(data, max_epochs=1)

    stderr_text = capsys.readouterr().err
    frames = [frame.strip() for frame in stderr_text.split("\r")]
    frames = [frame for frame in frames if frame]
    assert frames[-1] == "Epoch: [90/100] 90%|█████████ [00:00<00:00]"
示例14
def test_optimizer_params():
    """OptimizerParamsHandler logs each param group's lr through NeptuneLogger.log_metric."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_name):
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metric.assert_called_once_with(expected_name, y=0.01, x=123)

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr/group_0")
示例15
def test_output_handler_output_transform():
    """OutputHandler logs the transformed engine output through NeptuneLogger.log_metric."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_name):
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metric.assert_called_once_with(expected_name, y=12345, x=123)

    check(OutputHandler("tag", output_transform=lambda x: x), "tag/output")
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}), "another_tag/loss")
示例16
def test_weights_scalar_handler_wrong_setup():
    """WeightsScalarHandler rejects bad constructor arguments and non-Neptune loggers."""
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    model = MagicMock(spec=torch.nn.Module)
    bad_reductions = [
        (TypeError, "Argument reduction should be callable", 123),
        (ValueError, "Output of the reduction function should be a scalar", lambda x: x),
    ]
    for exc_type, pattern, reduction in bad_reductions:
        with pytest.raises(exc_type, match=pattern):
            WeightsScalarHandler(model, reduction=reduction)

    handler = WeightsScalarHandler(model)
    # A plain MagicMock is not a NeptuneLogger, so invoking the handler must fail.
    plain_logger = MagicMock()
    plain_engine = MagicMock()
    with pytest.raises(RuntimeError, match="Handler WeightsScalarHandler works only with NeptuneLogger"):
        handler(plain_engine, plain_logger, Events.ITERATION_STARTED)
示例17
def test_event_handler_iteration_started():
    """Time spent in ITERATION_STARTED handlers is reported under that event's stats."""
    delay = 0.1
    max_epochs = 1
    num_iters = 2

    profiler = BasicTimeProfiler()
    trainer = Engine(_do_nothing_update_fn)
    profiler.attach(trainer)

    @trainer.on(Events.ITERATION_STARTED)
    def sleep_on_iteration_start(engine):
        time.sleep(delay)

    trainer.run(range(num_iters), max_epochs=max_epochs)
    stats = profiler.get_results()["event_handlers_stats"]["ITERATION_STARTED"]

    tol = 1e-1
    assert stats["min/index"][0] == approx(delay, abs=tol)
    assert stats["max/index"][0] == approx(delay, abs=tol)
    assert stats["mean"] == approx(delay, abs=tol)
    assert stats["std"] == approx(0.0, abs=tol)
    assert stats["total"] == approx(max_epochs * num_iters * delay, abs=tol)
示例18
def test_output_handler_output_transform():
    """OutputHandler logs the transformed engine output through MLflowLogger.log_metrics."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=MLflowLogger)
        logger.log_metrics = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with({expected_key: 12345}, step=123)

    # MLflow metric names join tag and key with a space rather than "/".
    check(OutputHandler("tag", output_transform=lambda x: x), "tag output")
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}), "another_tag loss")
示例19
def test_optimizer_params():
    """OptimizerParamsHandler logs each param group's lr through MLflowLogger.log_metrics."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=MLflowLogger)
        logger.log_metrics = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metrics.assert_called_once_with({expected_key: 0.01}, step=123)

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr group_0")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator lr group_0")
示例20
def test_save_param_history():
    """With save_history=True the scheduler records each applied lr in state.param_history."""
    param = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, save_history=True)
    observed_lrs = []

    def record_lr(engine):
        observed_lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    assert not hasattr(trainer.state, "param_history")

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, record_lr)
    trainer.run([0] * 10, max_epochs=2)

    history = trainer.state.param_history["lr"]
    assert len(history) == len(observed_lrs)
    # Each history entry is a single-element list; compare its only value.
    assert [entry[0] for entry in history] == observed_lrs
示例21
def test_optimizer_params():
    """OptimizerParamsHandler logs each param group's lr through WandBLogger.log."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=WandBLogger)
        logger.log = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log.assert_called_once_with({expected_key: 0.01}, step=123, sync=None)

    check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    check(OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"), "generator/lr/group_0")
示例22
def test_output_handler_output_transform():
    """OutputHandler logs the transformed engine output through WandBLogger.log."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_key):
        logger = MagicMock(spec=WandBLogger)
        logger.log = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log.assert_called_once_with({expected_key: 12345}, step=123, sync=None)

    check(OutputHandler("tag", output_transform=lambda x: x), "tag/output")
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}), "another_tag/loss")
示例23
def test_output_handler_output_transform_sync():
    """The ``sync`` flag given to OutputHandler is forwarded verbatim to WandBLogger.log."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def check(handler, expected_key, expected_sync):
        logger = MagicMock(spec=WandBLogger)
        logger.log = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log.assert_called_once_with({expected_key: 12345}, step=123, sync=expected_sync)

    check(OutputHandler("tag", output_transform=lambda x: x, sync=False), "tag/output", False)
    check(OutputHandler("another_tag", output_transform=lambda x: {"loss": x}, sync=True), "another_tag/loss", True)
示例24
def discriminative_learning(self):
    """Build per-layer parameter groups and register a per-layer lr adjustment on the trainer.

    Transformer layer ``i`` (matched by the regex ``transformer.<any>.<i>.``) goes into a
    group named ``str(num_layers - i)``, so the last layer is ``'1'`` and the first is
    ``str(num_layers)``. Every parameter not matched as belonging to a numbered layer goes
    into group ``'0'``. An ``ITERATION_STARTED`` handler on ``self.trainer`` divides each of
    ``self.optimizer.param_groups``' ``"lr"`` by ``self.decreasing_factor ** int(name)``.

    Returns:
        list of dict: groups with ``'name'`` and ``'params'`` keys.
    """
    logger.info("Using discriminative learning as adaptation strategy")

    def params_where(keep):
        # Parameters of self.model whose qualified name satisfies `keep`.
        return [param for pname, param in self.model.named_parameters() if keep(pname)]

    parameter_groups = []
    for layer_idx in range(self.model.num_layers):
        layer_regex = r"transformer\.[^\.]*\." + str(layer_idx) + r"\."
        parameter_groups.append({
            'name': str(self.model.num_layers - layer_idx),
            'params': params_where(lambda pname: re.match(layer_regex, pname)),
        })

    # Remaining parameters (e.g. embeddings and classification head) share group '0'.
    any_layer_regex = r"transformer\.[^\.]*\.\d*\."
    parameter_groups.append({
        'name': '0',
        'params': params_where(lambda pname: not re.match(any_layer_regex, pname)),
    })

    # Sanity check: the groups together hold as many elements as the whole model.
    grouped_numel = sum(p.numel() for group in parameter_groups for p in group['params'])
    assert grouped_numel == sum(p.numel() for p in self.model.parameters())

    @self.trainer.on(Events.ITERATION_STARTED)
    def update_layer_learning_rates(engine):
        # NOTE(review): this divides the *current* lr in place on every iteration. Unless
        # something else (e.g. an lr scheduler) resets "lr" before this handler runs, the
        # division compounds and deeper groups' rates shrink geometrically — confirm intended.
        for group in self.optimizer.param_groups:
            depth = int(group["name"])
            group["lr"] = group["lr"] / (self.decreasing_factor ** depth)

    return parameter_groups
示例25
def __init__(self):
    """Initialise the parent class with per-iteration started/completed events."""
    per_iteration_events = dict(
        started=Events.ITERATION_STARTED,
        completed=Events.ITERATION_COMPLETED,
        iteration_completed=Events.ITERATION_COMPLETED,
    )
    super(BatchWise, self).__init__(**per_iteration_events)
示例26
def _setup_logging(logger, trainer, optimizers, evaluators, log_every_iters):
if optimizers is not None:
from torch.optim.optimizer import Optimizer
if not isinstance(optimizers, (Optimizer, Mapping)):
raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")
if evaluators is not None:
if not isinstance(evaluators, (Engine, Mapping)):
raise TypeError("Argument evaluators should be either a single engine or a dictionary or engines")
if log_every_iters is None:
log_every_iters = 1
logger.attach_output_handler(
trainer, event_name=Events.ITERATION_COMPLETED(every=log_every_iters), tag="training", metric_names="all"
)
if optimizers is not None:
# Log optimizer parameters
if isinstance(optimizers, Optimizer):
optimizers = {None: optimizers}
for k, optimizer in optimizers.items():
logger.attach_opt_params_handler(
trainer, Events.ITERATION_STARTED(every=log_every_iters), optimizer, param_name="lr", tag=k
)
if evaluators is not None:
# Log evaluation metrics
if isinstance(evaluators, Engine):
evaluators = {"validation": evaluators}
event_name = Events.ITERATION_COMPLETED if isinstance(logger, WandBLogger) else None
gst = global_step_from_engine(trainer, custom_event_name=event_name)
for k, evaluator in evaluators.items():
logger.attach_output_handler(
evaluator, event_name=Events.COMPLETED, tag=k, metric_names="all", global_step_transform=gst
)
示例27
def __init__(self):
    """Create the profiler's timers and the event / handler-method tables it works from."""
    # Separate timers for data loading, batch processing and event-handler execution.
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()

    # Timing buffers; left as None here and populated elsewhere.
    self.dataflow_times = None
    self.processing_times = None
    self.event_handlers_times = None

    # One row per profiled event: (event, its "_as_first_*" method, its "_as_last_*" method).
    # The three lists below are kept index-aligned.
    table = (
        (Events.EPOCH_STARTED, self._as_first_epoch_started, self._as_last_epoch_started),
        (Events.EPOCH_COMPLETED, self._as_first_epoch_completed, self._as_last_epoch_completed),
        (Events.ITERATION_STARTED, self._as_first_iter_started, self._as_last_iter_started),
        (Events.ITERATION_COMPLETED, self._as_first_iter_completed, self._as_last_iter_completed),
        (Events.GET_BATCH_STARTED, self._as_first_get_batch_started, self._as_last_get_batch_started),
        (Events.GET_BATCH_COMPLETED, self._as_first_get_batch_completed, self._as_last_get_batch_completed),
        (Events.COMPLETED, self._as_first_completed, self._as_last_completed),
    )
    self._events = [event for event, _, _ in table]
    self._fmethods = [first for _, first, _ in table]
    self._lmethods = [last for _, _, last in table]
示例28
def _reset(self, num_epochs, total_num_iters):
    """Allocate zero-filled timing tensors for a run of ``num_epochs`` / ``total_num_iters``."""
    self.dataflow_times = torch.zeros(total_num_iters)
    self.processing_times = torch.zeros(total_num_iters)

    # Buffer length per event: 1 for run-level events, num_epochs for epoch events,
    # total_num_iters for iteration and batch-fetch events. Key order is significant
    # only in that it matches the insertion order callers may iterate over.
    buffer_lengths = {
        Events.STARTED: 1,
        Events.COMPLETED: 1,
        Events.EPOCH_STARTED: num_epochs,
        Events.EPOCH_COMPLETED: num_epochs,
        Events.ITERATION_STARTED: total_num_iters,
        Events.ITERATION_COMPLETED: total_num_iters,
        Events.GET_BATCH_COMPLETED: total_num_iters,
        Events.GET_BATCH_STARTED: total_num_iters,
    }
    self.event_handlers_times = {event: torch.zeros(length) for event, length in buffer_lengths.items()}
示例29
def get_max_number_events(event_name, engine):
    """Return the event count bound for ``event_name``.

    Iteration events map to ``engine.state.epoch_length``, epoch events map to
    ``engine.state.max_epochs``, and any other event maps to 1.
    """
    if event_name in (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED):
        limit = engine.state.epoch_length
    elif event_name in (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED):
        limit = engine.state.max_epochs
    else:
        limit = 1
    return limit
示例30
def __init__(self, num_iters=100, prepare_batch=None):
    """Build an engine that benchmarks moving ``num_iters`` batches to ``idist.device()``.

    Args:
        num_iters (int): number of iterations after which the benchmark engine terminates.
        prepare_batch (callable, optional): called as
            ``prepare_batch(batch, device=..., non_blocking=False)`` and expected to return
            ``(x, y)``. If ``None``, the update step does nothing with the batch.
    """
    from ignite.handlers import Timer

    device = idist.device()

    def upload_to_gpu(engine, batch):
        # The timer below measures this iteration body; the returned tensors are not used further.
        if prepare_batch is not None:
            x, y = prepare_batch(batch, device=device, non_blocking=False)

    self.num_iters = num_iters
    self.benchmark_dataflow = Engine(upload_to_gpu)

    @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
    def stop_benchmark_dataflow(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Aim for about 100 progress dots. Clamp to 1: for num_iters < 100 the plain
        # integer division gives every=0, which Events rejects ("greater than zero").
        progress_every = max(1, num_iters // 100)

        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(every=progress_every))
        def show_progress_benchmark_dataflow(engine):
            print(".", end=" ")

    # Accumulate (average=False) only the time between ITERATION_STARTED and ITERATION_COMPLETED.
    self.timer = Timer(average=False)
    self.timer.attach(
        self.benchmark_dataflow,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )