Python源码示例:ignite.engine.Events.STARTED

示例1
def attach(self, trainer, train_loader):
        """Register a one-off dataflow benchmark that runs when ``trainer`` starts.

        On ``Events.STARTED`` the handler runs ``self.benchmark_dataflow`` over
        ``train_loader`` and reads the elapsed time from ``self.timer.value()``.
        Every rank executes the benchmark; only rank 0 prints the report.

        Args:
            trainer: engine whose ``Events.STARTED`` triggers the benchmark.
            train_loader: data passed to ``self.benchmark_dataflow.run``. When it is a
                ``torch.utils.data.DataLoader`` with a known ``batch_size``, an
                images-per-second figure is also printed.
        """

        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(" Total time ({} iterations) : {:.5f} seconds".format(self.num_iters, t))
                print(" time per iteration         : {} seconds".format(t / self.num_iters))

                # ``DataLoader.batch_size`` is None when the loader is built with a
                # ``batch_sampler`` or with automatic batching disabled. In that case
                # the number of images cannot be derived, so the line is skipped
                # instead of raising ``TypeError`` (and ``t == 0`` would divide by zero).
                batch_size = train_loader.batch_size if isinstance(train_loader, DataLoader) else None
                if batch_size is not None and t > 0:
                    num_images = batch_size * self.num_iters
                    print(" number of images / s       : {}".format(num_images / t))

                print("-" * 50)
示例2
def test_continue_training():
    """Re-running a finished engine with a larger ``max_epochs`` continues from where it stopped.

    Regression test for https://github.com/pytorch/ignite/issues/993
    """
    epochs = 2
    batches = range(10)
    engine = Engine(lambda e, b: 1)

    def check_state(run_state, n_epochs):
        assert run_state.max_epochs == n_epochs
        assert run_state.iteration == len(batches) * n_epochs
        assert run_state.epoch == n_epochs

    check_state(engine.run(batches, max_epochs=epochs), epochs)

    @engine.on(Events.STARTED)
    def assert_continue_training():
        # The second run must begin at the epoch reached by the first run.
        assert engine.state.epoch == epochs

    check_state(engine.run(batches, max_epochs=epochs * 2), epochs * 2)
示例3
def test_state_custom_attrs_init():
    """Custom attributes set on ``engine.state`` are still present during ``run``.

    Checked both with and without a preceding ``load_state_dict`` call.
    """

    def _test(with_load_state_dict=False):
        engine = Engine(lambda e, b: None)
        custom_attrs = {"alpha": 0.0, "beta": 1.0}
        for attr_name, attr_value in custom_attrs.items():
            setattr(engine.state, attr_name, attr_value)

        if with_load_state_dict:
            engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

        watched_events = Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED

        @engine.on(watched_events)
        def check_custom_attr():
            for attr_name, attr_value in custom_attrs.items():
                assert hasattr(engine.state, attr_name) and getattr(engine.state, attr_name) == attr_value

        engine.run([0, 1, 2, 3, 4], max_epochs=5)

    for flag in (False, True):
        _test(with_load_state_dict=flag)
示例4
def test_add_event_handler():
    """Handlers registered with an extra positional argument receive it after the engine."""
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    def increment(engine, counter):
        counter.count += 1

    counters = {}
    for event in (Events.STARTED, Events.COMPLETED):
        counters[event] = Counter()
        engine.add_event_handler(event, increment, counters[event])

    engine.run(15)

    for counter in counters.values():
        assert counter.count == 15
示例5
def test_add_event_handler_without_engine():
    """Handlers may omit the engine parameter, with or without extra arguments."""
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    started_counter, completed_counter = Counter(), Counter()

    def on_started():
        # Takes no arguments at all; the counter comes from the enclosing scope.
        started_counter.count += 1

    def on_completed(counter):
        # Takes only the user-supplied argument, no engine.
        counter.count += 1

    engine.add_event_handler(Events.STARTED, on_started)
    engine.add_event_handler(Events.COMPLETED, on_completed, completed_counter)

    engine.run(15)

    assert (started_counter.count, completed_counter.count) == (15, 15)
示例6
def test_has_event_handler():
    """``has_event_handler`` reports registration for a specific event and for any event."""
    engine = DummyEngine()
    started_handlers = [MagicMock(spec_set=True) for _ in range(2)]
    completed_handler = MagicMock(spec_set=True)

    for h in started_handlers:
        engine.add_event_handler(Events.STARTED, h)
    engine.add_event_handler(Events.COMPLETED, completed_handler)

    for h in started_handlers:
        assert engine.has_event_handler(h)
        assert engine.has_event_handler(h, Events.STARTED)
        for other_event in (Events.COMPLETED, Events.EPOCH_STARTED):
            assert not engine.has_event_handler(h, other_event)

    assert engine.has_event_handler(completed_handler)
    assert engine.has_event_handler(completed_handler, Events.COMPLETED)
    for other_event in (Events.STARTED, Events.EPOCH_STARTED):
        assert not engine.has_event_handler(completed_handler, other_event)
示例7
def test_args_and_kwargs_are_passed_to_event():
    """Positional and keyword arguments given at registration are forwarded to the handler."""
    engine = DummyEngine()
    kwargs = {"a": "a", "b": "b"}
    args = (1, 2, 3)

    def make_handler(event):
        spec_handler = create_autospec(spec=lambda e, x1, x2, x3, a, b: None)
        engine.add_event_handler(event, spec_handler, *args, **kwargs)
        return spec_handler

    handlers = [make_handler(event) for event in (Events.STARTED, Events.COMPLETED)]

    engine.run(1)
    called_handlers = list(filter(lambda h: h.called, handlers))
    assert len(called_handlers) == 2

    for handler in called_handlers:
        call_args, call_kwargs = handler.call_args
        assert call_args[0] == engine
        assert call_args[1:] == args
        assert call_kwargs == kwargs
示例8
def test_on_decorator():
    """``engine.on`` forwards extra positional arguments to the decorated handler."""
    engine = DummyEngine()

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    started_counter = Counter()
    completed_counter = Counter()

    @engine.on(Events.STARTED, started_counter)
    def handle_started(engine, counter):
        counter.count += 1

    @engine.on(Events.COMPLETED, completed_counter)
    def handle_completed(engine, counter):
        counter.count += 1

    engine.run(15)

    assert started_counter.count == completed_counter.count == 15
示例9
def test_event_handler_started():
    """Time spent in a STARTED handler is reported under ``event_handlers_stats["STARTED"]``."""
    true_event_handler_time = 0.1
    true_max_epochs = 2
    true_num_iters = 2

    profiler = BasicTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    def delay_start(engine):
        time.sleep(true_event_handler_time)

    dummy_trainer.add_event_handler(Events.STARTED, delay_start)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    started_stats = profiler.get_results()["event_handlers_stats"]["STARTED"]
    # NOTE(review): abs=1e-1 equals the expected value, so this only bounds the total to [0, 0.2] s.
    assert started_stats["total"] == approx(true_event_handler_time, abs=1e-1)
示例10
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
        """Hook this timer's control methods onto ``engine`` events.

        ``self.reset`` is always registered on ``start`` and ``self.pause`` on
        ``pause``. ``self.resume`` and ``self.step`` are registered only when the
        matching event is not None.

        Args:
            engine (ignite.engine.Engine):
                Engine that this timer will be attached to
            start (ignite.engine.Events):
                Event which should start (reset) the timer
            pause (ignite.engine.Events):
                Event which should pause the timer
            resume (ignite.engine.Events, optional):
                Event which should resume the timer
            step (ignite.engine.Events, optional):
                Event which should call the `step` method of the timer

        Returns:
            self (Timer)
        """
        register = engine.add_event_handler
        register(start, self.reset)
        register(pause, self.pause)

        for event, method_name in ((resume, "resume"), (step, "step")):
            if event is not None:
                register(event, getattr(self, method_name))

        return self
示例11
def test_not_chainerui_logger():
    """Calling OutputHandler with a logger that is not a ChainerUILogger raises RuntimeError."""
    handler = OutputHandler('test', metric_names='all')
    engine = MagicMock()

    from ignite.contrib.handlers.base_logger import BaseLogger

    wrong_logger = BaseLogger()

    with pytest.raises(RuntimeError, match='only with ChainerUILogger'):
        handler(engine, wrong_logger, Events.STARTED)
示例12
def test_not_unique_handler(client):
    """Attaching a second OutputHandler with an already-used tag to a ChainerUILogger fails."""
    # NOTE(review): ``client`` is not used in the body; presumably it is a fixture
    # requested for its setup side effects — confirm.
    tag = 'same_name'
    first, second = (OutputHandler(tag, metric_names='all') for _ in range(2))

    engine = MagicMock()
    logger = ChainerUILogger()
    logger.attach(engine, first, Events.STARTED)

    with pytest.raises(RuntimeError, match='unique tag name'):
        logger.attach(engine, second, Events.STARTED)
示例13
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
        """Hook this timer's control methods onto ``engine`` events.

        Args:
            engine (Engine):
                Engine that this timer will be attached to.
            start (Events):
                Event which should start (reset) the timer.
            pause (Events):
                Event which should pause the timer.
            resume (Events, optional):
                Event which should resume the timer; skipped when None.
            step (Events, optional):
                Event which should call the `step` method of the timer; skipped when None.

        Returns:
            self (Timer)
        """
        handlers = [(start, self.reset), (pause, self.pause)]
        if resume is not None:
            handlers.append((resume, self.resume))
        if step is not None:
            handlers.append((step, self.step))

        for event, handler in handlers:
            engine.add_event_handler(event, handler)

        return self
示例14
def attach(
        self,
        engine: Engine,
        start: Events = Events.STARTED,
        pause: Events = Events.COMPLETED,
        resume: Optional[Events] = None,
        step: Optional[Events] = None,
    ):
        """Hook this timer's control methods onto ``engine`` events.

        Args:
            engine (Engine):
                Engine that this timer will be attached to.
            start (Events):
                Event which should start (reset) the timer.
            pause (Events):
                Event which should pause the timer.
            resume (Events, optional):
                Event which should resume the timer; nothing is registered when None.
            step (Events, optional):
                Event which should call the `step` method of the timer; nothing is
                registered when None.

        Returns:
            self (Timer)
        """
        engine.add_event_handler(start, self.reset)
        engine.add_event_handler(pause, self.pause)

        for event, handler in ((resume, self.resume), (step, self.step)):
            if event is None:
                continue
            engine.add_event_handler(event, handler)

        return self
示例15
def _detach(self, trainer):
        """
        Detaches lr_finder from trainer.

        Each of ``self._run`` (on STARTED), ``self._warning`` (on COMPLETED) and
        ``self._reset`` (on COMPLETED) is removed only if it is currently
        registered; handlers that are absent are skipped.

        Args:
            trainer: the trainer to detach from.
        """
        registrations = (
            (self._run, Events.STARTED),
            (self._warning, Events.COMPLETED),
            (self._reset, Events.COMPLETED),
        )
        for handler, event in registrations:
            if not trainer.has_event_handler(handler, event):
                continue
            trainer.remove_event_handler(handler, event)
示例16
def attach(self, engine):
        """Register ``self.Events`` on ``engine`` and hook this object's handlers.

        ``self._on_started`` runs on ``Events.STARTED``. The periodic handlers run on
        the built-in ``<ATTR>_STARTED`` / ``<ATTR>_COMPLETED`` events, where
        ``<ATTR>`` is ``self.state_attr`` upper-cased.
        """
        engine.register_events(*self.Events)
        engine.add_event_handler(Events.STARTED, self._on_started)

        prefix = self.state_attr.upper()
        periodic_started = getattr(Events, prefix + "_STARTED")
        engine.add_event_handler(periodic_started, self._on_periodic_event_started)

        periodic_completed = getattr(Events, prefix + "_COMPLETED")
        engine.add_event_handler(periodic_completed, self._on_periodic_event_completed)
示例17
def _reset(self, num_epochs, total_num_iters):
        """Allocate fresh zero-filled timing buffers.

        Args:
            num_epochs: length of the buffers for the epoch-level events.
            total_num_iters: length of the dataflow/processing buffers and of the
                buffers for the iteration- and batch-level events.
        """
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)

        # Buffer length per event: 1 for STARTED/COMPLETED, ``num_epochs`` for epoch
        # events, ``total_num_iters`` for iteration and batch events.
        buffer_lengths = (
            (Events.STARTED, 1),
            (Events.COMPLETED, 1),
            (Events.EPOCH_STARTED, num_epochs),
            (Events.EPOCH_COMPLETED, num_epochs),
            (Events.ITERATION_STARTED, total_num_iters),
            (Events.ITERATION_COMPLETED, total_num_iters),
            (Events.GET_BATCH_COMPLETED, total_num_iters),
            (Events.GET_BATCH_STARTED, total_num_iters),
        )
        self.event_handlers_times = {event: torch.zeros(length) for event, length in buffer_lengths}
示例18
def _as_first_started(self, engine):
        """Prepare the profiling state at the start of a run and install the timing handlers.

        Sizes the timing buffers from the run's epoch count and iterations per
        epoch. Records the names of the handlers registered for each profiled event.
        Then edits the engine's handler lists directly:

        - appends ``self._as_last_started`` to the STARTED list;
        - places each ``self._fmethods`` entry at the front of its paired event in
          ``self._events``;
        - appends each ``self._lmethods`` entry to its paired event.

        Finally resets ``self._event_handlers_timer``.

        Args:
            engine: the engine being profiled. Its private ``_event_handlers``
                mapping is read and mutated directly.
        """
        # Iterations per epoch: the dataloader length when it has one, else ``epoch_length``.
        # NOTE(review): ``epoch_length`` is ignored when the dataloader has ``__len__``; if a run
        # uses an ``epoch_length`` larger than ``len(dataloader)`` the per-iteration buffers may
        # be too small — confirm against how those buffers are indexed.
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        # For each profiled event, the display names of its registered handlers
        # (handler entries are (callable, args, kwargs) tuples).
        self.event_handlers_names = {
            e: [
                h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(h)  # avoid adding internal handlers into output
            ]
            for e in Events
            if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        # ``_as_last_started`` is appended at the end of the STARTED handler list.
        engine._event_handlers[Events.STARTED].append((self._as_last_started, (engine,), {}))

        # ``_fmethods`` are placed at the front of each event's handler list...
        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine,), {}))

        # ...and ``_lmethods`` at the end.
        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine,), {}))

        # Let's go
        self._event_handlers_timer.reset()
示例19
def _as_last_completed(self, engine):
        """Store the COMPLETED handler timing and remove the profiler's temporary handlers.

        Writes ``self._event_handlers_timer.value()`` into the COMPLETED timing
        buffer. Then removes ``self._as_last_started`` from STARTED, and every
        ``self._fmethods`` / ``self._lmethods`` entry from its paired event in
        ``self._events``.
        """
        completed_time = self._event_handlers_timer.value()
        self.event_handlers_times[Events.COMPLETED][0] = completed_time

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        installed = list(zip(self._events, self._fmethods)) + list(zip(self._events, self._lmethods))
        for event, method in installed:
            engine.remove_event_handler(method, event)
示例20
def attach(self, engine):
        """Attach the profiler by placing ``self._as_first_started`` first on STARTED.

        Does nothing if the engine already has that handler.

        Raises:
            TypeError: if ``engine`` is not an ``Engine``.
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, " "but given {}".format(type(engine)))

        if engine.has_event_handler(self._as_first_started):
            return

        # Inserted at index 0 of the STARTED list (bypassing add_event_handler),
        # ahead of any handler already registered there.
        engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {}))
示例21
def test_last_event_name():
    """Inside each handler ``last_event_name`` is that handler's event; after the run it is COMPLETED."""
    engine = Engine(MagicMock(return_value=1))
    assert engine.last_event_name is None

    def expect_last_event(event):
        def _check(_engine):
            assert _engine.last_event_name == event

        return _check

    for event in (
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
    ):
        engine.add_event_handler(event, expect_last_event(event))

    engine.run([0, 1])
    assert engine.last_event_name == Events.COMPLETED
示例22
def test_state_get_event_attrib_value():
    """``State.get_event_attrib_value`` maps plain and filtered events to ``iteration`` or ``epoch``."""
    state = State()
    state.iteration = 10
    state.epoch = 9

    cases = [
        (Events.ITERATION_STARTED, "iteration"),
        (Events.ITERATION_COMPLETED, "iteration"),
        (Events.EPOCH_STARTED, "epoch"),
        (Events.EPOCH_COMPLETED, "epoch"),
        (Events.STARTED, "epoch"),
        (Events.COMPLETED, "epoch"),
        (Events.ITERATION_STARTED(every=10), "iteration"),
        (Events.ITERATION_COMPLETED(every=10), "iteration"),
        (Events.EPOCH_STARTED(once=5), "epoch"),
        (Events.EPOCH_COMPLETED(once=5), "epoch"),
    ]
    for event, attr_name in cases:
        assert state.get_event_attrib_value(event) == getattr(state, attr_name)
示例23
def _test_check_triggered_events(data, max_epochs, epoch_length, exp_iter_stops=None):
    """Run a trivial engine and check how many times each built-in event fired.

    Args:
        data: data passed to ``engine.run``.
        max_epochs: number of epochs to run.
        epoch_length: iterations per epoch passed to ``engine.run``.
        exp_iter_stops: expected count of ``DATALOADER_STOP_ITERATION``; when None,
            ``max_epochs - 1`` is expected.
    """
    engine = Engine(lambda e, b: 1)
    watched = (
        Events.STARTED,
        Events.EPOCH_STARTED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.EPOCH_COMPLETED,
        Events.COMPLETED,
        Events.GET_BATCH_STARTED,
        Events.GET_BATCH_COMPLETED,
        Events.DATALOADER_STOP_ITERATION,
    )

    handlers = {}
    for event in watched:
        handlers[event] = MagicMock()
        engine.add_event_handler(event, handlers[event])

    engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

    total_iters = max_epochs * epoch_length
    if exp_iter_stops is None:
        exp_iter_stops = max_epochs - 1
    expected_num_calls = {
        Events.STARTED: 1,
        Events.COMPLETED: 1,
        Events.EPOCH_STARTED: max_epochs,
        Events.EPOCH_COMPLETED: max_epochs,
        Events.ITERATION_STARTED: total_iters,
        Events.ITERATION_COMPLETED: total_iters,
        Events.GET_BATCH_STARTED: total_iters,
        Events.GET_BATCH_COMPLETED: total_iters,
        Events.DATALOADER_STOP_ITERATION: exp_iter_stops,
    }

    for event, handler in handlers.items():
        expected = expected_num_calls[event]
        assert handler.call_count == expected, "{}: {} vs {}".format(event, handler.call_count, expected)
示例24
def test_custom_events_with_events_list():
    """A handler registered on a list combining a custom event and ``Events.STARTED`` gets called."""

    class CustomEvents(EventEnum):
        TEST_EVENT = "test_event"

    def _update(engine, batch):
        engine.fire_event(CustomEvents.TEST_EVENT)

    engine = Engine(_update)
    engine.register_events(*CustomEvents)

    handle = MagicMock()
    events_list = CustomEvents.TEST_EVENT | Events.STARTED
    engine.add_event_handler(events_list, handle)
    engine.run(range(1))

    assert handle.called
示例25
def run(self, num_times):
        """Fire STARTED followed by COMPLETED ``num_times`` times on a fresh ``State``.

        Returns:
            The ``State`` instance assigned to ``self.state`` for this call.
        """
        self.state = State()
        for _ in range(num_times):
            for event in (Events.STARTED, Events.COMPLETED):
                self.fire_event(event)
        return self.state
示例26
def test_add_event_handler_raises_with_invalid_signature():
    """``add_event_handler`` rejects handlers whose signature cannot accept the given args/kwargs."""
    engine = Engine(MagicMock())

    def handler(engine):
        pass

    def handler_with_args(engine, a):
        pass

    def handler_with_kwargs(engine, b=42):
        pass

    def handler_with_args_and_kwargs(engine, a, b=42):
        pass

    def register(fn, *args, **kwargs):
        engine.add_event_handler(Events.STARTED, fn, *args, **kwargs)

    register(handler)
    register(handler, 1)

    register(handler_with_args, 1)
    with pytest.raises(ValueError):
        register(handler_with_args)

    register(handler_with_kwargs, b=2)
    with pytest.raises(ValueError):
        register(handler_with_kwargs, c=3)
    register(handler_with_kwargs, 1, b=2)

    register(handler_with_args_and_kwargs, 1, b=2)
    register(handler_with_args_and_kwargs, 1, 2, b=2)
    with pytest.raises(ValueError):
        register(handler_with_args_and_kwargs, 1, b=2, c=3)
示例27
def test_remove_event_handler():
    """``remove_event_handler`` validates its inputs and removes only the requested handler."""
    engine = DummyEngine()

    with pytest.raises(ValueError, match=r"Input event name"):
        engine.remove_event_handler(lambda x: x, "an event")

    def on_started(engine):
        return 0

    engine.add_event_handler(Events.STARTED, on_started)

    with pytest.raises(ValueError, match=r"Input handler"):
        engine.remove_event_handler(lambda x: x, Events.STARTED)

    h1, h2, m = (MagicMock(spec_set=True) for _ in range(3))
    for h in (h1, h2):
        engine.add_event_handler(Events.EPOCH_STARTED, h)
    engine.add_event_handler(Events.EPOCH_COMPLETED, m)

    # Look the list up on every check rather than caching it; removal may rebind it.
    def registered(event):
        return engine._event_handlers[event]

    assert len(registered(Events.EPOCH_STARTED)) == 2
    engine.remove_event_handler(h1, Events.EPOCH_STARTED)
    assert len(registered(Events.EPOCH_STARTED)) == 1
    assert registered(Events.EPOCH_STARTED)[0][0] == h2

    assert len(registered(Events.EPOCH_COMPLETED)) == 1
    engine.remove_event_handler(m, Events.EPOCH_COMPLETED)
    assert len(registered(Events.EPOCH_COMPLETED)) == 0
示例28
def test_attach():
    """A logger's ``log_handler`` is called with ``(engine, logger, event)`` the expected number of times."""
    n_epochs = 5
    data = list(range(50))

    def _test(event, n_calls):
        losses_iter = iter(torch.rand(n_epochs * len(data)))
        trainer = Engine(lambda engine, batch: next(losses_iter))

        logger = DummyLogger()
        mock_log_handler = MagicMock()
        logger.attach(trainer, log_handler=mock_log_handler, event_name=event)

        trainer.run(data, max_epochs=n_epochs)

        mock_log_handler.assert_called_with(trainer, logger, event)
        assert mock_log_handler.call_count == n_calls

    total_iters = len(data) * n_epochs
    cases = [
        (Events.ITERATION_STARTED, total_iters),
        (Events.ITERATION_COMPLETED, total_iters),
        (Events.EPOCH_STARTED, n_epochs),
        (Events.EPOCH_COMPLETED, n_epochs),
        (Events.STARTED, 1),
        (Events.COMPLETED, 1),
        (Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs),
    ]
    for event, n_calls in cases:
        _test(event, n_calls)
示例29
def get_prepared_engine(true_event_handler_time):
    """Build an engine with a sleeping handler on each of eight built-in events.

    Every handler sleeps ``true_event_handler_time`` seconds per call.

    Returns:
        An ``Engine`` built on ``_do_nothing_update_fn`` with the handlers attached.
    """
    dummy_trainer = Engine(_do_nothing_update_fn)

    # Each handler keeps its own function name because profilers may report
    # handlers by ``__qualname__``.
    def delay_start(engine):
        time.sleep(true_event_handler_time)

    def delay_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_start(engine):
        time.sleep(true_event_handler_time)

    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_start(engine):
        time.sleep(true_event_handler_time)

    def delay_iter_complete(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_started(engine):
        time.sleep(true_event_handler_time)

    def delay_get_batch_completed(engine):
        time.sleep(true_event_handler_time)

    for event, handler in (
        (Events.STARTED, delay_start),
        (Events.COMPLETED, delay_complete),
        (Events.EPOCH_STARTED, delay_epoch_start),
        (Events.EPOCH_COMPLETED, delay_epoch_complete),
        (Events.ITERATION_STARTED, delay_iter_start),
        (Events.ITERATION_COMPLETED, delay_iter_complete),
        (Events.GET_BATCH_STARTED, delay_get_batch_started),
        (Events.GET_BATCH_COMPLETED, delay_get_batch_completed),
    ):
        dummy_trainer.add_event_handler(event, handler)

    return dummy_trainer
示例30
def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None):
        """Hook this timer's control methods onto ``engine`` events.

        Args:
            engine (ignite.engine.Engine):
                Engine that this timer will be attached to
            start (ignite.engine.Events):
                Event which should start (reset) the timer
            pause (ignite.engine.Events):
                Event which should pause the timer
            resume (ignite.engine.Events, optional):
                Event which should resume the timer (ignored when None)
            step (ignite.engine.Events, optional):
                Event which should call the `step` method of the timer (ignored when None)

        Returns:
            self (Timer)
        """
        mandatory = ((start, self.reset), (pause, self.pause))
        optional = tuple(
            (event, handler) for event, handler in ((resume, self.resume), (step, self.step)) if event is not None
        )

        for event, handler in mandatory + optional:
            engine.add_event_handler(event, handler)

        return self