Python source examples: ignite.engine.Events.EPOCH_COMPLETED
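The examples below are collected from open-source projects and show handlers being registered for Events.EPOCH_COMPLETED, which fires once at the end of every epoch. As a baseline, here is a minimal self-contained sketch (names are illustrative) of the two registration styles used throughout:

from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)  # dummy update function

# style 1: decorator registration
@engine.on(Events.EPOCH_COMPLETED)
def on_epoch_completed(engine):
    print(f"epoch {engine.state.epoch} completed")

# style 2: explicit registration; extra positional args are forwarded to the handler
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: None)

engine.run([0, 1, 2], max_epochs=2)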

Example 1
def test_metrics_print(self):
        tempdir = tempfile.mkdtemp()
        # start from a non-existent path: the handler's SummaryWriter is expected to recreate it
        shutil.rmtree(tempdir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return batch + 1.0

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get("acc", 0.1)
            engine.state.metrics["acc"] = current_metric + 0.1

        # set up testing handler
        stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
        stats_handler.attach(engine)
        engine.run(range(3), max_epochs=2)
        # check logging output

        self.assertTrue(os.path.exists(tempdir))
        shutil.rmtree(tempdir) 
Example 2
def attach(self, engine, metric_names=None, output_transform=None):
        """
        Attaches the progress bar to an engine object.

        Args:
            engine (Engine): engine object.
            metric_names (list, optional): list of metric names to log as the bar progresses.
            output_transform (callable, optional): a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.
        """
        if metric_names is not None and not isinstance(metric_names, list):
            raise TypeError("metric_names should be a list, got {} instead.".format(type(metric_names)))

        if output_transform is not None and not callable(output_transform):
            raise TypeError("output_transform should be a function, got {} instead."
                            .format(type(output_transform)))

        engine.add_event_handler(Events.ITERATION_COMPLETED, self._update, metric_names, output_transform)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._close) 
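A usage sketch for the attach method above, assuming it belongs to ignite's contrib ProgressBar (the snippet matches its older signature): the ITERATION_COMPLETED handler advances the bar and the EPOCH_COMPLETED handler closes it at the end of each epoch.

from ignite.engine import Engine
from ignite.contrib.handlers import ProgressBar

engine = Engine(lambda engine, batch: 0.0)  # dummy update returning a scalar "loss"
pbar = ProgressBar()
# display the engine output under the name "loss" as the bar progresses
pbar.attach(engine, output_transform=lambda out: {"loss": out})
engine.run(range(10), max_epochs=2)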
Example 3
def test_terminate_at_end_of_epoch_stops_run():
    max_epochs = 5
    last_epoch_to_run = 3

    engine = Engine(MagicMock(return_value=1))

    def end_of_epoch_handler(engine):
        if engine.state.epoch == last_epoch_to_run:
            engine.terminate()

    engine.add_event_handler(Events.EPOCH_COMPLETED, end_of_epoch_handler)

    assert not engine.should_terminate

    state = engine.run([1], max_epochs=max_epochs)

    assert state.epoch == last_epoch_to_run
    assert engine.should_terminate 
Example 4
def test_time_stored_in_state():
    def _test(data, max_epochs, epoch_length):
        sleep_time = 0.01
        engine = Engine(lambda e, b: time.sleep(sleep_time))

        def check_epoch_time(engine):
            assert engine.state.times[Events.EPOCH_COMPLETED.name] >= sleep_time * epoch_length

        def check_completed_time(engine):
            assert engine.state.times[Events.COMPLETED.name] >= sleep_time * epoch_length * max_epochs

        engine.add_event_handler(Events.EPOCH_COMPLETED, lambda e: check_epoch_time(e))
        engine.add_event_handler(Events.COMPLETED, lambda e: check_completed_time(e))

        engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

    _test(list(range(100)), max_epochs=2, epoch_length=100)
    _test(list(range(200)), max_epochs=2, epoch_length=100)
    _test(list(range(200)), max_epochs=5, epoch_length=100) 
Example 5
def test_has_handler_on_callable_events():
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED)

    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
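Events.EPOCH_COMPLETED(every=3) above is a filtered event. A short sketch of the filtering variants available in recent ignite versions (every, once, and a custom filter):

from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)

# fires at the end of epochs 3, 6, 9, ...
engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), lambda e: None)

# fires once, at the end of epoch 5 only
engine.add_event_handler(Events.EPOCH_COMPLETED(once=5), lambda e: None)

# custom filter: called with (engine, event counter), returns a bool
engine.add_event_handler(
    Events.EPOCH_COMPLETED(event_filter=lambda eng, epoch: epoch % 2 == 1),
    lambda e: None,
)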
Example 6
def test_state_custom_attrs_init():
    def _test(with_load_state_dict=False):
        engine = Engine(lambda e, b: None)
        engine.state.alpha = 0.0
        engine.state.beta = 1.0

        if with_load_state_dict:
            engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

        @engine.on(Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED)
        def check_custom_attr():
            assert hasattr(engine.state, "alpha") and engine.state.alpha == 0.0
            assert hasattr(engine.state, "beta") and engine.state.beta == 1.0

        engine.run([0, 1, 2, 3, 4], max_epochs=5)

    _test()
    _test(with_load_state_dict=True) 
Example 7
def test_with_engine_early_stopping():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    # best score 1.5 at epoch 4; epochs 5-7 bring no improvement, so patience=3
    # terminates training at epoch 7
    scores = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 7
    assert trainer.state.epoch == 7 
Example 8
def test_with_engine_early_stopping_on_plateau():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    def score_function(engine):
        # constant score: never improves after the first evaluation, so
        # patience=4 stops the run at epoch 5
        return 0.047

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=4, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 5
    assert trainer.state.epoch == 5 
Example 9
def test_with_engine_no_early_stopping():
    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    n_epochs_counter = Counter()

    # every dip recovers within the patience=5 window, so all 10 epochs run
    scores = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_counter.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert n_epochs_counter.count == 10
    assert trainer.state.epoch == 10 
Example 10
def test_save_best_model_by_val_score(dirname):

    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)
    model = DummyModel()

    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        evaluator.run([0, 1])

    @evaluator.on(Events.EPOCH_COMPLETED)
    def set_eval_metric(engine):
        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}

    save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)

    trainer.run([0, 1], max_epochs=len(acc_scores))

    # n_saved=2 keeps the two best checkpoints: epoch 8 (acc=0.61) and epoch 9 (acc=0.70)
    assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}
Example 11
def test_integration(dirname):

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    with pytest.warns(UserWarning, match="TrainsSaver: running in bypass mode"):
        TrainsLogger.set_bypass_mode(True)
        logger = TrainsLogger(output_uri=dirname)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            logger.trains_logger.report_scalar(title="", series="", value="test_value", iteration=global_step)

        logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)
        logger.close() 
Example 12
def test_integration():

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    plx_logger = PolyaxonLogger()

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        logger.log_metrics(step=global_step, **{"{}".format("test_value"): global_step})

    plx_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    trainer.run(data, max_epochs=n_epochs) 
Example 13
def test_integration_as_context_manager():

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with PolyaxonLogger() as plx_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            logger.log_metrics(step=global_step, **{"{}".format("test_value"): global_step})

        plx_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs) 
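Examples 11-13 share one pattern: the logger's attach call binds a log_handler to Events.EPOCH_COMPLETED. The same wiring with ignite's built-in TensorboardLogger and OutputHandler, as a sketch reusing the trainer from the example above (the log directory is a placeholder):

from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler

tb_logger = TensorboardLogger(log_dir="./tb_logs")  # placeholder path
tb_logger.attach(
    trainer,
    log_handler=OutputHandler(tag="training", metric_names="all"),
    event_name=Events.EPOCH_COMPLETED,
)
tb_logger.close()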
Example 14
def test_pbar_wrong_events_order():

    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.ITERATION_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.ITERATION_STARTED)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10)) 
Example 15
def test_integration():
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    npt_logger = NeptuneLogger(offline_mode=True)

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        logger.log_metric("test_value", global_step, global_step)

    npt_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    trainer.run(data, max_epochs=n_epochs)
    npt_logger.close() 
Example 16
def test_integration_as_context_manager():
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with NeptuneLogger(offline_mode=True) as npt_logger:
        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            logger.log_metric("test_value", global_step, global_step)

        npt_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs) 
Example 17
def test_event_handler_epoch_completed():
    true_event_handler_time = 0.1
    true_max_epochs = 2
    true_num_iters = 1

    profiler = BasicTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.EPOCH_COMPLETED)
    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)
    results = profiler.get_results()
    event_results = results["event_handlers_stats"]["EPOCH_COMPLETED"]

    assert event_results["min/index"][0] == approx(true_event_handler_time, abs=1e-1)
    assert event_results["max/index"][0] == approx(true_event_handler_time, abs=1e-1)
    assert event_results["mean"] == approx(true_event_handler_time, abs=1e-1)
    assert event_results["std"] == approx(0.0, abs=1e-1)
    assert event_results["total"] == approx(true_max_epochs * true_event_handler_time, abs=1e-1) 
Example 18
def attach(self, engine, name):
        engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name) 
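The attach above is the standard wiring for an ignite Metric: reset on EPOCH_STARTED, accumulate on ITERATION_COMPLETED, and write engine.state.metrics[name] on EPOCH_COMPLETED. A minimal custom metric following that contract (an illustrative sketch, not ignite's own implementation):

from ignite.metrics import Metric

class MeanOutput(Metric):
    # averages the engine output over one epoch; name and logic are illustrative

    def reset(self):
        self._sum = 0.0
        self._num = 0

    def update(self, output):
        self._sum += float(output)
        self._num += 1

    def compute(self):
        return self._sum / max(self._num, 1)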
Example 19
def test_metrics_writer(self):
        tempdir = tempfile.mkdtemp()
        # start from a non-existent path: the handler's SummaryWriter is expected to recreate it
        shutil.rmtree(tempdir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return batch + 1.0

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get("acc", 0.1)
            engine.state.metrics["acc"] = current_metric + 0.1

        # set up testing handler
        writer = SummaryWriter(log_dir=tempdir)
        stats_handler = TensorBoardStatsHandler(
            writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0
        )
        stats_handler.attach(engine)
        engine.run(range(3), max_epochs=2)
        # check logging output
        self.assertTrue(os.path.exists(tempdir))
        self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
        shutil.rmtree(tempdir) 
Example 20
def test_metrics_print(self):
        log_stream = StringIO()
        logging.basicConfig(stream=log_stream, level=logging.INFO)
        key_to_handler = "test_logging"
        key_to_print = "testing_metric"

        # set up engine
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get(key_to_print, 0.1)
            engine.state.metrics[key_to_print] = current_metric + 0.1

        # set up testing handler
        stats_handler = StatsHandler(name=key_to_handler)
        stats_handler.attach(engine)

        engine.run(range(3), max_epochs=2)

        # check logging output
        output_str = log_stream.getvalue()
        grep = re.compile(f".*{key_to_handler}.*")
        has_key_word = re.compile(f".*{key_to_print}.*")
        for idx, line in enumerate(output_str.split("\n")):
            if grep.match(line):
                if idx in [5, 10]:
                    self.assertTrue(has_key_word.match(line)) 
Example 21
def test_content(self):
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        data = [0] * 8

        # set up engine
        def _train_func(engine, batch):
            pass

        val_engine = Engine(_train_func)
        train_engine = Engine(_train_func)

        @train_engine.on(Events.EPOCH_COMPLETED)
        def run_validation(engine):
            val_engine.run(data)
            val_engine.state.metrics["val_loss"] = 1

        # set up testing handler
        net = torch.nn.PReLU()

        def _reduce_lr_on_plateau():
            optimizer = torch.optim.SGD(net.parameters(), 0.1)
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
            handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
            handler.attach(train_engine)
            return lr_scheduler

        def _reduce_on_step():
            optimizer = torch.optim.SGD(net.parameters(), 0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
            handler = LrScheduleHandler(lr_scheduler)
            handler.attach(train_engine)
            return lr_scheduler

        schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

        train_engine.run(data, max_epochs=5)
        for scheduler in schedulers:
            np.testing.assert_allclose(scheduler._last_lr[0], 0.001) 
Example 22
def __call__(self, engine, logger, event_name):
        if not isinstance(logger, ChainerUILogger):
            raise RuntimeError(
                '`chainerui.contrib.ignite.handler.OutputHandler` works only '
                'with ChainerUILogger, but set {}'.format(type(logger)))

        metrics = self._setup_output_metrics(engine)
        if not metrics:
            return
        iteration = self.global_step_transform(
            engine, Events.ITERATION_COMPLETED)
        epoch = self.global_step_transform(engine, Events.EPOCH_COMPLETED)

        # convert metrics name
        rendered_metrics = {}
        for k, v in metrics.items():
            rendered_metrics['{}/{}'.format(self.tag, k)] = v
        rendered_metrics['iteration'] = iteration
        rendered_metrics['epoch'] = epoch
        if 'elapsed_time' not in rendered_metrics:
            rendered_metrics['elapsed_time'] = _get_time() - logger.start_at

        if self.interval <= 1:
            logger.post_log([rendered_metrics])
            return

        # enable interval, cache metrics
        logger.cache.setdefault(self.tag, []).append(rendered_metrics)
        # select the appropriate event set by handler init
        global_count = self.global_step_transform(engine, event_name)
        if global_count % self.interval == 0:
            logger.post_log(logger.cache[self.tag])
            logger.cache[self.tag].clear() 
Example 23
def test_empty_metrics(client):
    handler = OutputHandler('test', metric_names='all')

    engine = MagicMock()
    engine.state.metrics = {}
    logger = ChainerUILogger()

    with logger:
        handler(engine, logger, Events.EPOCH_COMPLETED)
    client.post_log.assert_not_called() 
Example 24
def test_post_metrics(client):
    handler = OutputHandler('test', metric_names='all')

    metrics = {'loss': 0.1}
    engine = MagicMock()
    engine.state.metrics = metrics
    logger = ChainerUILogger()

    with logger:
        handler(engine, logger, Events.EPOCH_COMPLETED)
    client.post_log.assert_called_once() 
Example 25
def test_post_metrics_with_interval(client):

    def stepper(engine, event_name):
        engine.state.step += 1
        return engine.state.step

    handler = OutputHandler(
        'test', metric_names='all', interval_step=2,
        global_step_transform=stepper)

    metrics = {'loss': 0.1}
    engine = MagicMock()
    engine.state.step = 0
    engine.state.metrics = metrics
    logger = ChainerUILogger()

    with logger:
        handler(engine, logger, Events.EPOCH_COMPLETED)
        # metrics are cached and not posted yet
        client.post_log.assert_not_called()
        assert 'test' in logger.cache
        assert len(logger.cache['test']) == 1

        handler(engine, logger, Events.EPOCH_COMPLETED)
        assert not logger.cache['test']
        client.post_log.assert_called_once()

        handler(engine, logger, Events.EPOCH_COMPLETED)
        assert len(logger.cache['test']) == 1
        client.post_log.assert_called_once()

    # remainder metrics are posted after logger exit
    assert client.post_log.call_count == 2 
Example 26
def attach(self, engine, name):
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
        if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) 
Example 27
def setup_checkpoint(self, base_model, classifier, setops_model, evaluator):
        """Save checkpoints of the models."""

        checkpoint_handler_acc = ModelCheckpoint(
            self.results_path,
            CKPT_PREFIX,
            score_function=lambda eng: round(
                (eng.state.metrics["fake class acc"] + eng.state.metrics["S class acc"] +
                 eng.state.metrics["I class acc"] + eng.state.metrics["U class acc"]) / 4,
                3
            ),
            score_name="val_acc",
            n_saved=2,
            require_empty=False
        )
        checkpoint_handler_last = ModelCheckpoint(
            self.results_path,
            CKPT_PREFIX,
            save_interval=2,
            n_saved=2,
            require_empty=False
        )
        evaluator.add_event_handler(
            event_name=Events.EPOCH_COMPLETED,
            handler=checkpoint_handler_acc,
            to_save={
                'base_model': base_model.state_dict(),
                'classifier': classifier.state_dict(),
                'setops_model': setops_model.state_dict(),
            }
        )
        evaluator.add_event_handler(
            event_name=Events.EPOCH_COMPLETED,
            handler=checkpoint_handler_last,
            to_save={
                'base_model': base_model.state_dict(),
                'classifier': classifier.state_dict(),
                'setops_model': setops_model.state_dict(),
            }
        ) 
Example 28
def __init__(
        self,
        *,
        patience,
        score_function,
        out_of_patience_callback,
        training_engine: Engine,
        validation_engine: Engine,
        module: torch.nn.Module = None,
        optimizer: torch.optim.Optimizer = None,
    ):

        if not callable(score_function):
            raise TypeError("Argument score_function should be a function")

        if patience < 1:
            raise ValueError("Argument patience should be positive integer")

        self.score_function = score_function
        self.out_of_patience_callback = out_of_patience_callback
        self.module = module
        self.optimizer = optimizer

        self.patience = patience
        self.counter = 0

        self.best_score = None
        self.best_module_state_dict = None
        self.best_optimizer_state_dict = None
        self.restore_epoch = None

        self.training_engine = training_engine
        self.validation_engine = validation_engine
        validation_engine.add_event_handler(Events.EPOCH_COMPLETED, self.on_epoch_completed)
        training_engine.add_event_handler(Events.COMPLETED, self.on_completed) 
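A usage sketch for the handler above. The class name RestoringEarlyStopping is hypothetical (the original class name is not shown) and the callback signature is an assumption; the constructor registers itself on both engines, so constructing it is the only wiring needed:

# hypothetical names throughout: RestoringEarlyStopping, trainer, evaluator, model, optimizer
handler = RestoringEarlyStopping(
    patience=3,
    score_function=lambda engine: engine.state.metrics["val_acc"],
    out_of_patience_callback=lambda *args, **kwargs: print("out of patience"),
    training_engine=trainer,
    validation_engine=evaluator,
    module=model,          # optional: lets the handler snapshot the best weights
    optimizer=optimizer,   # optional: snapshot optimizer state alongside
)
trainer.run(train_loader, max_epochs=100)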
Example 29
def objective(trial):
    # Create a convolutional neural network.
    model = Net(trial)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        model.cuda(device)

    optimizer = Adam(model.parameters())
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()}, device=device)

    # Register a pruning handler to the evaluator.
    pruning_handler = optuna.integration.PyTorchIgnitePruningHandler(trial, "accuracy", trainer)
    evaluator.add_event_handler(Events.COMPLETED, pruning_handler)

    # Load MNIST dataset.
    train_loader, val_loader = get_data_loaders(TRAIN_BATCH_SIZE, VAL_BATCH_SIZE)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(engine):
        evaluator.run(val_loader)
        validation_acc = evaluator.state.metrics["accuracy"]
        print("Epoch: {} Validation accuracy: {:.2f}".format(engine.state.epoch, validation_acc))

    trainer.run(train_loader, max_epochs=EPOCHS)

    evaluator.run(val_loader)
    return evaluator.state.metrics["accuracy"] 
Example 30
def __init__(self):
        super(EpochWise, self).__init__(
            started=Events.EPOCH_STARTED,
            completed=Events.EPOCH_COMPLETED,
            iteration_completed=Events.ITERATION_COMPLETED,
        )
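EpochWise is ignite's default MetricUsage: start accumulating on EPOCH_STARTED, update on every ITERATION_COMPLETED, and finalize on EPOCH_COMPLETED. In recent ignite versions it can also be passed explicitly when attaching a metric, as in this sketch (evaluator is assumed to output a (y_pred, y) pair, which Accuracy requires):

from ignite.metrics import Accuracy
from ignite.metrics.metric import EpochWise

metric = Accuracy()
# equivalent to the default usage: the value is computed once per epoch
metric.attach(evaluator, "accuracy", usage=EpochWise())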