Python源码示例:ignite.engine.Events.STARTED
示例1
def attach(self, trainer, train_loader):
    """Run a one-off dataflow benchmark when ``trainer`` fires ``Events.STARTED``.

    The registered handler runs ``self.benchmark_dataflow`` over
    ``train_loader`` and reads the elapsed time from ``self.timer``. On
    process rank 0 only, it prints:

    - the total time;
    - the time per iteration;
    - when the loader is a ``DataLoader`` with a known ``batch_size``, the
      number of images processed per second.

    Args:
        trainer: engine whose ``Events.STARTED`` triggers the benchmark.
        train_loader: data passed to ``self.benchmark_dataflow.run``.
    """
    from torch.utils.data import DataLoader

    @trainer.on(Events.STARTED)
    def run_benchmark(_):
        # Every process runs the benchmark, but only rank 0 prints, so the
        # report appears once in a distributed run.
        is_main_process = idist.get_rank() == 0
        if is_main_process:
            print("-" * 50)
            print(" - Dataflow benchmark")

        self.benchmark_dataflow.run(train_loader)
        t = self.timer.value()

        if not is_main_process:
            return

        print(" ")
        print(" Total time ({} iterations) : {:.5f} seconds".format(self.num_iters, t))
        print(" time per iteration : {} seconds".format(t / self.num_iters))
        # A DataLoader built from a ``batch_sampler`` has ``batch_size is None``.
        # Throughput is unknown then, so skip that line instead of failing on
        # ``None * int``. Also skip it if no time was measured (avoids a
        # division by zero).
        batch_size = train_loader.batch_size if isinstance(train_loader, DataLoader) else None
        if batch_size is not None and t > 0:
            num_images = batch_size * self.num_iters
            print(" number of images / s : {}".format(num_images / t))
        print("-" * 50)
示例2
def test_continue_training():
    # Regression test for https://github.com/pytorch/ignite/issues/993:
    # re-running a finished engine with a larger max_epochs must continue
    # from the reached state rather than start over.
    data = range(10)
    first_run_epochs = 2
    engine = Engine(lambda e, b: 1)

    def _check_state(state, n_epochs):
        assert state.max_epochs == n_epochs
        assert state.iteration == len(data) * n_epochs
        assert state.epoch == n_epochs

    _check_state(engine.run(data, max_epochs=first_run_epochs), first_run_epochs)

    @engine.on(Events.STARTED)
    def assert_continue_training():
        # When the second run starts, the epoch counter still holds the
        # value reached by the first run.
        assert engine.state.epoch == first_run_epochs

    _check_state(engine.run(data, max_epochs=first_run_epochs * 2), first_run_epochs * 2)
示例3
def test_state_custom_attrs_init():
    # Custom attributes set on engine.state before run() must survive the
    # run, with and without a prior load_state_dict().
    def _run_and_check(with_load_state_dict=False):
        engine = Engine(lambda e, b: None)
        engine.state.alpha = 0.0
        engine.state.beta = 1.0
        if with_load_state_dict:
            engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

        watched_events = Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED

        @engine.on(watched_events)
        def check_custom_attr():
            assert hasattr(engine.state, "alpha")
            assert engine.state.alpha == 0.0
            assert hasattr(engine.state, "beta")
            assert engine.state.beta == 1.0

        engine.run([0, 1, 2, 3, 4], max_epochs=5)

    for load_first in (False, True):
        _run_and_check(with_load_state_dict=load_first)
示例4
def test_add_event_handler():
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    def _increment(_engine, counter):
        # Receives the engine first, then the extra argument given at registration.
        counter.count += 1

    started_counter = Counter()
    completed_counter = Counter()
    engine.add_event_handler(Events.STARTED, _increment, started_counter)
    engine.add_event_handler(Events.COMPLETED, _increment, completed_counter)

    engine.run(15)

    for counter in (started_counter, completed_counter):
        assert counter.count == 15
示例5
def test_add_event_handler_without_engine():
    # Handlers whose signature does not take the engine must still be called.
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    started_counter = Counter()
    completed_counter = Counter()

    def on_started():
        started_counter.count += 1

    def on_completed(counter):
        counter.count += 1

    engine.add_event_handler(Events.STARTED, on_started)
    engine.add_event_handler(Events.COMPLETED, on_completed, completed_counter)

    engine.run(15)

    for counter in (started_counter, completed_counter):
        assert counter.count == 15
示例6
def test_has_event_handler():
    engine = DummyEngine()
    started_handlers = [MagicMock(spec_set=True) for _ in range(2)]
    completed_handler = MagicMock(spec_set=True)

    for h in started_handlers:
        engine.add_event_handler(Events.STARTED, h)
    engine.add_event_handler(Events.COMPLETED, completed_handler)

    # Handlers registered on STARTED are found there (and with no event
    # given), but not on other events.
    for h in started_handlers:
        assert engine.has_event_handler(h, Events.STARTED)
        assert engine.has_event_handler(h)
        assert not engine.has_event_handler(h, Events.COMPLETED)
        assert not engine.has_event_handler(h, Events.EPOCH_STARTED)

    assert not engine.has_event_handler(completed_handler, Events.STARTED)
    assert engine.has_event_handler(completed_handler, Events.COMPLETED)
    assert engine.has_event_handler(completed_handler)
    assert not engine.has_event_handler(completed_handler, Events.EPOCH_STARTED)
示例7
def test_args_and_kwargs_are_passed_to_event():
    engine = DummyEngine()
    kwargs = {"a": "a", "b": "b"}
    args = (1, 2, 3)

    def _register(event):
        # Autospec mock so add_event_handler's signature check sees a real
        # (engine, x1, x2, x3, a, b) signature.
        mock_handler = create_autospec(spec=lambda e, x1, x2, x3, a, b: None)
        engine.add_event_handler(event, mock_handler, *args, **kwargs)
        return mock_handler

    handlers = [_register(event) for event in (Events.STARTED, Events.COMPLETED)]
    engine.run(1)

    called_handlers = [h for h in handlers if h.called]
    assert len(called_handlers) == 2
    for h in called_handlers:
        call_args, call_kwargs = h.call_args
        assert call_args[0] == engine
        assert call_args[1:] == args
        assert call_kwargs == kwargs
示例8
def test_on_decorator():
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    started_counter = Counter()
    completed_counter = Counter()

    # The decorator forwards its extra positional argument to the handler.
    @engine.on(Events.STARTED, started_counter)
    def on_started(_engine, counter):
        counter.count += 1

    @engine.on(Events.COMPLETED, completed_counter)
    def on_completed(_engine, counter):
        counter.count += 1

    engine.run(15)

    for counter in (started_counter, completed_counter):
        assert counter.count == 15
示例9
def test_event_handler_started():
    # A STARTED handler sleeps for a known time; the profiler must attribute
    # roughly that much time to the STARTED event.
    sleep_time = 0.1
    n_epochs = 2
    n_iters = 2

    profiler = BasicTimeProfiler()
    trainer = Engine(_do_nothing_update_fn)
    profiler.attach(trainer)

    @trainer.on(Events.STARTED)
    def delay_start(engine):
        time.sleep(sleep_time)

    trainer.run(range(n_iters), max_epochs=n_epochs)

    started_stats = profiler.get_results()["event_handlers_stats"]["STARTED"]
    assert started_stats["total"] == approx(sleep_time, abs=1e-1)
示例10
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
    """ Hook this timer's control methods onto events of ``engine``.
    Args:
        engine (ignite.engine.Engine):
            Engine whose events drive the timer
        start (ignite.engine.Events):
            Event on which ``reset`` is called (the timer (re)starts)
        pause (ignite.engine.Events):
            Event on which ``pause`` is called
        resume (ignite.engine.Events, optional):
            Event on which ``resume`` is called; not registered when None
        step (ignite.engine.Events, optional):
            Event on which ``step`` is called; not registered when None
    Returns:
        self (Timer), so the call can be chained
    """
    # start and pause are registered unconditionally.
    engine.add_event_handler(start, self.reset)
    engine.add_event_handler(pause, self.pause)

    optional_hooks = ((resume, self.resume), (step, self.step))
    for event, callback in optional_hooks:
        if event is not None:
            engine.add_event_handler(event, callback)
    return self
示例11
def test_not_chainerui_logger():
    # OutputHandler must refuse to log through anything but a ChainerUILogger.
    output_handler = OutputHandler('test', metric_names='all')
    fake_engine = MagicMock()
    from ignite.contrib.handlers.base_logger import BaseLogger
    generic_logger = BaseLogger()

    with pytest.raises(RuntimeError, match='only with ChainerUILogger'):
        output_handler(fake_engine, generic_logger, Events.STARTED)
示例12
def test_not_unique_handler(client):
    # Two OutputHandlers sharing a tag cannot both be attached to one logger.
    first_handler, duplicate_handler = (
        OutputHandler('same_name', metric_names='all') for _ in range(2)
    )
    fake_engine = MagicMock()
    ui_logger = ChainerUILogger()

    ui_logger.attach(fake_engine, first_handler, Events.STARTED)
    with pytest.raises(RuntimeError, match='unique tag name'):
        ui_logger.attach(fake_engine, duplicate_handler, Events.STARTED)
示例13
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
    """ Wire this timer's reset/pause/resume/step methods to engine events.
    Args:
        engine (Engine):
            Engine whose events drive the timer.
        start (Events):
            Event on which ``reset`` is called (the timer (re)starts).
        pause (Events):
            Event on which ``pause`` is called.
        resume (Events, optional):
            Event on which ``resume`` is called; skipped when None.
        step (Events, optional):
            Event on which ``step`` is called; skipped when None.
    Returns:
        self (Timer)
    """
    # start and pause are always registered; resume/step only when given.
    engine.add_event_handler(start, self.reset)
    engine.add_event_handler(pause, self.pause)

    for event, callback in ((resume, self.resume), (step, self.step)):
        if event is None:
            continue
        engine.add_event_handler(event, callback)
    return self
示例14
def attach(
    self,
    engine: Engine,
    start: Events = Events.STARTED,
    pause: Events = Events.COMPLETED,
    resume: Optional[Events] = None,
    step: Optional[Events] = None,
):
    """ Wire this timer's reset/pause/resume/step methods to engine events.
    Args:
        engine (Engine):
            Engine whose events drive the timer.
        start (Events):
            Event on which ``reset`` is called (the timer (re)starts).
        pause (Events):
            Event on which ``pause`` is called.
        resume (Events, optional):
            Event on which ``resume`` is called; skipped when None.
        step (Events, optional):
            Event on which ``step`` is called; skipped when None.
    Returns:
        self (Timer)
    """
    # start and pause are always registered; resume/step only when given.
    engine.add_event_handler(start, self.reset)
    engine.add_event_handler(pause, self.pause)

    optional_hooks = ((resume, self.resume), (step, self.step))
    for event, callback in optional_hooks:
        if event is not None:
            engine.add_event_handler(event, callback)
    return self
示例15
def _detach(self, trainer):
    """
    Remove the lr_finder handlers from ``trainer`` when they are present.
    Args:
        trainer: the trainer to detach from.
    """
    # (handler, event) pairs that may have been attached; each one is
    # removed only if it is actually registered on that event.
    attached = (
        (self._run, Events.STARTED),
        (self._warning, Events.COMPLETED),
        (self._reset, Events.COMPLETED),
    )
    for handler, event in attached:
        if trainer.has_event_handler(handler, event):
            trainer.remove_event_handler(handler, event)
示例16
def attach(self, engine):
    """Register ``self.Events`` on ``engine`` and hook this object's handlers.

    ``_on_started`` runs on ``Events.STARTED``. The periodic handlers run on
    ``Events.<ATTR>_STARTED`` and ``Events.<ATTR>_COMPLETED``, where
    ``<ATTR>`` is ``self.state_attr`` upper-cased.
    """
    engine.register_events(*self.Events)
    engine.add_event_handler(Events.STARTED, self._on_started)

    prefix = self.state_attr.upper()
    periodic_hooks = (
        ("STARTED", self._on_periodic_event_started),
        ("COMPLETED", self._on_periodic_event_completed),
    )
    for suffix, callback in periodic_hooks:
        event = getattr(Events, "{}_{}".format(prefix, suffix))
        engine.add_event_handler(event, callback)
示例17
def _reset(self, num_epochs, total_num_iters):
    """Allocate zero-filled timing buffers for a run of the given size.

    Args:
        num_epochs: number of epochs; sizes the per-epoch event buffers.
        total_num_iters: total number of iterations; sizes the dataflow,
            processing and per-iteration event buffers.
    """
    self.dataflow_times = torch.zeros(total_num_iters)
    self.processing_times = torch.zeros(total_num_iters)

    # Each event's buffer is sized by how often that event is expected to fire.
    buffer_sizes = (
        (Events.STARTED, 1),
        (Events.COMPLETED, 1),
        (Events.EPOCH_STARTED, num_epochs),
        (Events.EPOCH_COMPLETED, num_epochs),
        (Events.ITERATION_STARTED, total_num_iters),
        (Events.ITERATION_COMPLETED, total_num_iters),
        (Events.GET_BATCH_COMPLETED, total_num_iters),
        (Events.GET_BATCH_STARTED, total_num_iters),
    )
    self.event_handlers_times = {event: torch.zeros(size) for event, size in buffer_sizes}
示例18
def _as_first_started(self, engine):
    """Size the timing buffers and install the profiler's internal handlers.

    ``attach`` in this file inserts this method at index 0 of the
    ``Events.STARTED`` handler list. At that point it is placed ahead of the
    STARTED handlers registered so far.

    Args:
        engine: the engine being profiled. Its ``state`` is read, and its
            private ``_event_handlers`` registry is read and mutated here.
    """
    # Iterations per epoch: the dataloader's length when it defines one,
    # otherwise the engine's epoch_length.
    # NOTE(review): if run() was given an epoch_length larger than
    # len(dataloader), the buffers below are sized from len(dataloader) and
    # may be too small -- confirm.
    if hasattr(engine.state.dataloader, "__len__"):
        num_iters_per_epoch = len(engine.state.dataloader)
    else:
        num_iters_per_epoch = engine.state.epoch_length
    self.max_epochs = engine.state.max_epochs
    self.total_num_iters = self.max_epochs * num_iters_per_epoch
    # (Re)allocate the per-event timing storage for this run.
    self._reset(self.max_epochs, self.total_num_iters)
    # For every event not listed in events_to_ignore, record the names of
    # the handlers currently registered on it. Registry entries are
    # (handler, args, kwargs) tuples; only the handler is used.
    self.event_handlers_names = {
        e: [
            h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__
            for (h, _, _) in engine._event_handlers[e]
            if "BasicTimeProfiler." not in repr(h)  # avoid adding internal handlers into output
        ]
        for e in Events
        if e not in self.events_to_ignore
    }
    # Setup all other handlers:
    # _as_last_started is appended, so it runs after the STARTED handlers
    # registered so far.
    engine._event_handlers[Events.STARTED].append((self._as_last_started, (engine,), {}))
    # Each "first" method goes to the front of its event's handler list and
    # each "last" method to the end. Together they bracket the handlers
    # already registered on that event.
    for e, m in zip(self._events, self._fmethods):
        engine._event_handlers[e].insert(0, (m, (engine,), {}))
    for e, m in zip(self._events, self._lmethods):
        engine._event_handlers[e].append((m, (engine,), {}))
    # Let's go
    # Reset the handlers timer so timing starts from this point.
    self._event_handlers_timer.reset()
示例19
def _as_last_completed(self, engine):
    """Store the COMPLETED handlers' time and uninstall the internal handlers.

    The current value of ``self._event_handlers_timer`` is saved as the
    COMPLETED entry. Then every handler installed by ``_as_first_started``
    is removed from ``engine``.
    """
    elapsed = self._event_handlers_timer.value()
    self.event_handlers_times[Events.COMPLETED][0] = elapsed

    # Remove added handlers: the STARTED tail hook first, then the "first"
    # methods and finally the "last" methods for each profiled event.
    engine.remove_event_handler(self._as_last_started, Events.STARTED)
    for internal_methods in (self._fmethods, self._lmethods):
        for event, method in zip(self._events, internal_methods):
            engine.remove_event_handler(method, event)
示例20
def attach(self, engine):
    """Attach the profiler to ``engine`` by installing its entry handler.

    Raises:
        TypeError: if ``engine`` is not an ``Engine``.
    """
    if not isinstance(engine, Engine):
        raise TypeError("Argument engine should be ignite.engine.Engine, but given {}".format(type(engine)))

    # Already attached: nothing to do.
    if engine.has_event_handler(self._as_first_started):
        return

    # Put the entry handler ahead of the STARTED handlers registered so far.
    entry = (self._as_first_started, (engine,), {})
    engine._event_handlers[Events.STARTED].insert(0, entry)
示例21
def test_last_event_name():
    engine = Engine(MagicMock(return_value=1))
    assert engine.last_event_name is None

    def _expect(expected_event):
        # Build a handler asserting that ``expected_event`` is the one being fired.
        def _check(_engine):
            assert _engine.last_event_name == expected_event

        return _check

    for event in (
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
    ):
        engine.on(event)(_expect(event))

    engine.run([0, 1])
    assert engine.last_event_name == Events.COMPLETED
示例22
def test_state_get_event_attrib_value():
    state = State()
    state.iteration = 10
    state.epoch = 9

    # Iteration-level events map to state.iteration. Epoch-level and
    # STARTED/COMPLETED map to state.epoch. Filtered events map like their
    # base event.
    cases = [
        (Events.ITERATION_STARTED, state.iteration),
        (Events.ITERATION_COMPLETED, state.iteration),
        (Events.EPOCH_STARTED, state.epoch),
        (Events.EPOCH_COMPLETED, state.epoch),
        (Events.STARTED, state.epoch),
        (Events.COMPLETED, state.epoch),
        (Events.ITERATION_STARTED(every=10), state.iteration),
        (Events.ITERATION_COMPLETED(every=10), state.iteration),
        (Events.EPOCH_STARTED(once=5), state.epoch),
        (Events.EPOCH_COMPLETED(once=5), state.epoch),
    ]
    for event, expected in cases:
        assert state.get_event_attrib_value(event) == expected
示例23
def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops=None):
    # Run an engine with a mock on every built-in event, then compare each
    # mock's call count with the count expected for the given run shape.
    engine = Engine(lambda e, b: 1)
    watched_events = (
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
        Events.COMPLETED,
        Events.GET_BATCH_STARTED,
        Events.GET_BATCH_COMPLETED,
        Events.DATALOADER_STOP_ITERATION,
    )
    handlers = {}
    for event in watched_events:
        handlers[event] = MagicMock()
        engine.add_event_handler(event, handlers[event])

    engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

    n_iterations = max_epochs * epoch_length
    n_stops = (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops
    expected_num_calls = {
        Events.STARTED: 1,
        Events.COMPLETED: 1,
        Events.EPOCH_STARTED: max_epochs,
        Events.EPOCH_COMPLETED: max_epochs,
        Events.ITERATION_STARTED: n_iterations,
        Events.ITERATION_COMPLETED: n_iterations,
        Events.GET_BATCH_STARTED: n_iterations,
        Events.GET_BATCH_COMPLETED: n_iterations,
        Events.DATALOADER_STOP_ITERATION: n_stops,
    }
    for event, handler in handlers.items():
        expected = expected_num_calls[event]
        assert handler.call_count == expected, "{}: {} vs {}".format(event, handler.call_count, expected)
示例24
def test_custom_events_with_events_list():
    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    def process_func(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(process_func)
    engine.register_events(*CustomEvents)

    # A handler registered on a combined (custom | built-in) event list
    # must be called.
    handle = MagicMock()
    combined_events = CustomEvents.TEST_EVENT | Events.STARTED
    engine.add_event_handler(combined_events, handle)

    engine.run(range(1))
    assert handle.called
示例25
def run(self, num_times):
    """Reset ``self.state``, then fire STARTED followed by COMPLETED ``num_times`` times.

    Returns:
        ``self.state`` after all events have been fired.
    """
    self.state = State()
    for _ in range(num_times):
        for event in (Events.STARTED, Events.COMPLETED):
            self.fire_event(event)
    return self.state
示例26
def test_add_event_handler_raises_with_invalid_signature():
    engine = Engine(MagicMock())

    def handler(engine):
        pass

    def handler_with_args(engine, a):
        pass

    def handler_with_kwargs(engine, b=42):
        pass

    def handler_with_args_and_kwargs(engine, a, b=42):
        pass

    # (handler, extra positional args, extra keyword args, accepted?)
    # Registrations are attempted in this order; rejected ones must raise
    # ValueError.
    cases = [
        (handler, (), {}, True),
        (handler, (1,), {}, True),
        (handler_with_args, (1,), {}, True),
        (handler_with_args, (), {}, False),
        (handler_with_kwargs, (), {"b": 2}, True),
        (handler_with_kwargs, (), {"c": 3}, False),
        (handler_with_kwargs, (1,), {"b": 2}, True),
        (handler_with_args_and_kwargs, (1,), {"b": 2}, True),
        (handler_with_args_and_kwargs, (1, 2), {"b": 2}, True),
        (handler_with_args_and_kwargs, (1,), {"b": 2, "c": 3}, False),
    ]
    for fn, extra_args, extra_kwargs, accepted in cases:
        if accepted:
            engine.add_event_handler(Events.STARTED, fn, *extra_args, **extra_kwargs)
        else:
            with pytest.raises(ValueError):
                engine.add_event_handler(Events.STARTED, fn, *extra_args, **extra_kwargs)
示例27
def test_remove_event_handler():
    engine = DummyEngine()

    def _registered(event):
        # Re-read on every call: removal may replace the stored list.
        return engine._event_handlers[event]

    # Unknown event name and unknown handler are both rejected.
    with pytest.raises(ValueError, match=r"Input event name"):
        engine.remove_event_handler(lambda x: x, "an event")

    def on_started(engine):
        return 0

    engine.add_event_handler(Events.STARTED, on_started)
    with pytest.raises(ValueError, match=r"Input handler"):
        engine.remove_event_handler(lambda x: x, Events.STARTED)

    h1, h2 = MagicMock(spec_set=True), MagicMock(spec_set=True)
    m = MagicMock(spec_set=True)
    for h in (h1, h2):
        engine.add_event_handler(Events.EPOCH_STARTED, h)
    engine.add_event_handler(Events.EPOCH_COMPLETED, m)

    assert len(_registered(Events.EPOCH_STARTED)) == 2
    engine.remove_event_handler(h1, Events.EPOCH_STARTED)
    assert len(_registered(Events.EPOCH_STARTED)) == 1
    assert _registered(Events.EPOCH_STARTED)[0][0] == h2

    assert len(_registered(Events.EPOCH_COMPLETED)) == 1
    engine.remove_event_handler(m, Events.EPOCH_COMPLETED)
    assert len(_registered(Events.EPOCH_COMPLETED)) == 0
示例28
def test_attach():
    n_epochs = 5
    data = list(range(50))

    def _check_attach(event, n_calls):
        # The trainer's output per iteration is a fresh random "loss".
        losses_iter = iter(torch.rand(n_epochs * len(data)))

        def update_fn(engine, batch):
            return next(losses_iter)

        trainer = Engine(update_fn)
        logger = DummyLogger()
        mock_log_handler = MagicMock()
        logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
        trainer.run(data, max_epochs=n_epochs)

        # The log handler receives (engine, logger, event) and fires n_calls times.
        mock_log_handler.assert_called_with(trainer, logger, event)
        assert mock_log_handler.call_count == n_calls

    n_iterations = len(data) * n_epochs
    cases = (
        (Events.ITERATION_STARTED, n_iterations),
        (Events.ITERATION_COMPLETED, n_iterations),
        (Events.EPOCH_STARTED, n_epochs),
        (Events.EPOCH_COMPLETED, n_epochs),
        (Events.STARTED, 1),
        (Events.COMPLETED, 1),
        (Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs),
    )
    for event, n_calls in cases:
        _check_attach(event, n_calls)
示例29
def get_prepared_engine(true_event_handler_time):
    """Build a no-op trainer with a handler on each profiled event.

    Each handler sleeps ``true_event_handler_time`` seconds. The handler
    function names are kept distinct, because profilers report handlers by
    qualified name.
    """
    dummy_trainer = Engine(_do_nothing_update_fn)

    def delay_start(engine):
        time.sleep(true_event_handler_time)

    def delay_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_start(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_start(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_started(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_completed(engine):
        time.sleep(true_event_handler_time)

    registrations = (
        (Events.STARTED, delay_start),
        (Events.COMPLETED, delay_complete),
        (Events.EPOCH_STARTED, delay_epoch_start),
        (Events.EPOCH_COMPLETED, delay_epoch_complete),
        (Events.ITERATION_STARTED, delay_iter_start),
        (Events.ITERATION_COMPLETED, delay_iter_complete),
        (Events.GET_BATCH_STARTED, delay_get_batch_started),
        (Events.GET_BATCH_COMPLETED, delay_get_batch_completed),
    )
    for event, sleeper in registrations:
        dummy_trainer.add_event_handler(event, sleeper)
    return dummy_trainer
示例30
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
    """ Hook this timer's control methods onto events of ``engine``.
    Args:
        engine (ignite.engine.Engine):
            Engine whose events drive the timer
        start (ignite.engine.Events):
            Event on which ``reset`` is called (the timer (re)starts)
        pause (ignite.engine.Events):
            Event on which ``pause`` is called
        resume (ignite.engine.Events, optional):
            Event on which ``resume`` is called; not registered when None
        step (ignite.engine.Events, optional):
            Event on which ``step`` is called; not registered when None
    Returns:
        self (Timer), so the call can be chained
    """
    # start and pause are registered unconditionally.
    engine.add_event_handler(start, self.reset)
    engine.add_event_handler(pause, self.pause)

    optional_hooks = ((resume, self.resume), (step, self.step))
    for event, callback in optional_hooks:
        if event is not None:
            engine.add_event_handler(event, callback)
    return self