Python source code examples: torch.utils.create_exp_dir()
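
Note: despite the title, create_exp_dir() is not part of torch.utils itself; in the examples below it comes from each repository's own utils / project_utils module. For reference, here is a minimal sketch of the helper, modeled on the DARTS-style implementation (exact details vary between repositories):

import os
import shutil

def create_exp_dir(path, scripts_to_save=None):
    # Create the experiment directory if it does not exist yet.
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    # Optionally snapshot the source files that launched the run.
    if scripts_to_save is not None:
        os.makedirs(os.path.join(path, 'scripts'), exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)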

Example 1
def __init__(self):
        super(Helper, self).__init__()

        # Remember the original save root before it is overwritten below.
        # Requires: from copy import copy, plus os and yaml imports.
        self.args._save = copy(self.args.save)
        # Nest the experiment directory by search space, dataset and
        # hyper-parameters, suffixed with the job id.
        self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                    self.args.space,
                                                    self.args.dataset,
                                                    self.args.search_dp,
                                                    self.args.search_wd,
                                                    self.args.job_id)

        utils.create_exp_dir(self.args.save)

        # Dump the run configuration once, next to the original save root.
        config_filename = os.path.join(self.args._save, 'config.yaml')
        if not os.path.exists(config_filename):
            with open(config_filename, 'w') as f:
                yaml.dump(self.args_to_log, f, default_flow_style=False)

        # CIFAR-100 has 100 classes; the other supported datasets have 10.
        if self.args.dataset != 'cifar100':
            self.args.n_classes = 10
        else:
            self.args.n_classes = 100
Example 2
def __init__(self):
        super(Helper, self).__init__()

        # Same setup as Example 1, but the directory name is built from
        # drop_path_prob and weight_decay instead of the search_* fields.
        self.args._save = copy(self.args.save)
        self.args.save = '{}/{}/{}/{}_{}-{}'.format(self.args.save,
                                                    self.args.space,
                                                    self.args.dataset,
                                                    self.args.drop_path_prob,
                                                    self.args.weight_decay,
                                                    self.args.job_id)

        utils.create_exp_dir(self.args.save)

        config_filename = os.path.join(self.args._save, 'config.yaml')
        if not os.path.exists(config_filename):
            with open(config_filename, 'w') as f:
                yaml.dump(self.args_to_log, f, default_flow_style=False)

        if self.args.dataset != 'cifar100':
            self.args.n_classes = 10
        else:
            self.args.n_classes = 100

        # Disable cutout augmentation when drop_path_prob is 0.
        if self.args.drop_path_prob == 0:
            self.args.cutout = False
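
To make the resulting directory layout concrete, here is a self-contained illustration of the same format string with hypothetical values standing in for self.args.* (every value below is made up):

# Hypothetical argument values; only the format string is from the examples.
path = '{}/{}/{}/{}_{}-{}'.format(
    'experiments',  # args.save (root directory)
    'darts',        # args.space
    'cifar10',      # args.dataset
    0.2,            # args.drop_path_prob
    3e-4,           # args.weight_decay
    42)             # args.job_id
print(path)  # experiments/darts/cifar10/0.2_0.0003-42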
Example 3
def initialize_run(self, sub_dir_path=None):
        args = self.args
        # Rebind the local name `utils` to the project-level utility module.
        utils = project_utils
        if not self.args.continue_train:
            self.sub_directory_path = 'WeightSharingNasBenchNetRandom-{}_SEED_{}'.format(self.args.save, self.args.seed)
            self.exp_dir = os.path.join(self.args.main_path, self.sub_directory_path)
            utils.create_exp_dir(self.exp_dir)
        # When resuming, exp_dir is assumed to have been restored elsewhere.
        if self.args.visualize:
            self.viz_dir_path = utils.create_viz_dir(self.exp_dir)

        if self.args.tensorboard:
            self.tb_dir = self.exp_dir
            tboard_dir = os.path.join(self.args.tboard_dir, self.sub_directory_path)
            self.writer = SummaryWriter(tboard_dir)

        # Anomaly detection is slow but pinpoints the op that produced NaN/Inf.
        if self.args.debug:
            torch.autograd.set_detect_anomaly(True)

        # Route all messages to both the console and <exp_dir>/log.txt.
        self.logger = utils.get_logger(
            "train_search",
            file_handler=utils.get_file_handler(os.path.join(self.exp_dir, 'log.txt')),
            level=logging.INFO if not args.debug else logging.DEBUG
        )
        logging.info(f"setting random seed as {args.seed}")
        utils.torch_random_seed(args.seed)
        logging.info('gpu number = %d' % args.gpus)
        logging.info("args = %s", args)

        # Separate loss instances for training and evaluation.
        criterion = nn.CrossEntropyLoss().cuda()
        eval_criterion = nn.CrossEntropyLoss().cuda()
        self.eval_loss = eval_criterion

        # Two copies of the CIFAR-10 training set: one with augmentation for
        # the weight-training split, one with the plain eval transform for the
        # architecture-validation split. The held-out test set is untouched.
        train_transform, valid_transform = utils._data_transforms_cifar10(args.cutout_length if args.cutout else None)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=valid_transform)
        test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

        # Split the training indices at the configured ratio.
        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.nao_search_config.ratio * num_train))

        # Disjoint SubsetRandomSamplers keep the two splits from overlapping.
        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=args.nao_search_config.child_eval_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        test_queue = torch.utils.data.DataLoader(
            test_data, batch_size=args.evaluate_batch_size,
            shuffle=False, pin_memory=True, num_workers=8)

        return train_queue, valid_queue, test_queue, criterion, eval_criterion
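
The train/validation split in Example 3 hinges on two SubsetRandomSamplers drawing disjoint index ranges from the same training pool. A stripped-down, runnable sketch of that pattern on dummy tensors (the dataset size and the 0.8 ratio are arbitrary here):

import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# 100 fake CIFAR-sized images with random labels.
data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

indices = list(range(len(data)))
split = int(np.floor(0.8 * len(data)))  # arbitrary 80/20 split

# Disjoint samplers: the first slice trains weights, the remainder validates.
train_queue = DataLoader(data, batch_size=16,
                         sampler=SubsetRandomSampler(indices[:split]))
valid_queue = DataLoader(data, batch_size=16,
                         sampler=SubsetRandomSampler(indices[split:]))

print(len(train_queue.sampler), len(valid_queue.sampler))  # 80 20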