Python source code examples: torch.utils.data
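The examples below are collected from open-source projects and show torch.utils.data in real training and evaluation code. Note that several examples also reference a project-local utils.data module (e.g., utils.data.Dataset, utils.data.Collate), which is distinct from torch.utils.data. As a minimal, self-contained sketch of the core API (all names below are illustrative, not taken from any of the projects):

import torch

# Wrap in-memory tensors in a Dataset; DataLoader then handles batching,
# shuffling and (optionally) multi-process loading.
features = torch.randn(100, 3)
labels = torch.randint(0, 2, (100,))
dataset = torch.utils.data.TensorDataset(features, labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
for batch_features, batch_labels in loader:
    pass  # one training/evaluation step per batch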

Example 1
def trainBatch(net, criterion, optimizer):
    # image, text, length, converter, crnn and train_iter are defined
    # elsewhere in the original training script (module-level globals).
    data = next(train_iter)  # Python 3: next(iterator), not iterator.next()
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()
    return cost 
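The four-argument criterion call follows the CTC-loss convention: preds has shape (T, N, C), preds_size repeats the sequence length T for every batch item, and length holds the target text lengths. As a sketch with made-up shapes (not code from the project above), the equivalent call with the built-in torch.nn.CTCLoss looks like this:

import torch
import torch.nn as nn

T, N, C = 26, 4, 37                               # time steps, batch size, alphabet size
log_probs = torch.randn(T, N, C).log_softmax(2)   # CTC expects log-probabilities
targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # label 0 is reserved for the blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

criterion = nn.CTCLoss()
loss = criterion(log_probs, targets, input_lengths, target_lengths)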
Example 2
def run(args, testset, action):
    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    LOGGER.debug('Testing (PID=%d), %s', os.getpid(), args)

    model = action.create_model()
    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    model.to(args.device)

    # dataloader
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=1, shuffle=False, num_workers=args.workers)

    # testing
    LOGGER.debug('tests, begin')
    action.eval_1(model, testloader, args.device)
    LOGGER.debug('tests, end') 
Example 3
def eval_1(self, model, testloader, device):
        model.eval()
        with open(self.filename, 'w') as fout:
            self.eval_1__header(fout)
            with torch.no_grad():
                for i, data in enumerate(testloader):
                    p0, p1, igt = data
                    res = self.do_estimate(p0, p1, model, device) # --> [1, 4, 4]
                    ig_gt = igt.cpu().contiguous().view(-1, 4, 4) # --> [1, 4, 4]
                    g_hat = res.cpu().contiguous().view(-1, 4, 4) # --> [1, 4, 4]

                    dg = g_hat.bmm(ig_gt) # if correct, dg == identity matrix.
                    dx = ptlk.se3.log(dg) # --> [1, 6] (if correct, dx == zero vector)
                    dn = dx.norm(p=2, dim=1) # --> [1]
                    dm = dn.mean()

                    self.eval_1__write(fout, ig_gt, g_hat)
                    LOGGER.info('test, %d/%d, %f', i, len(testloader), dm) 
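When the estimate is exact, g_hat is the inverse of ig_gt, so dg is the identity transform; ptlk.se3.log maps dg to its 6-vector twist coordinates, and the norm dn is therefore a direct scalar measure of pose error (zero for a perfect estimate).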
Example 4
def eval_1(self, model, testloader, device):
        model.eval()
        vloss = 0.0
        gloss = 0.0
        count = 0
        with torch.no_grad():
            for i, data in enumerate(testloader):
                loss, loss_g = self.compute_loss(model, data, device)

                vloss1 = loss.item()
                vloss += vloss1
                gloss1 = loss_g.item()
                gloss += gloss1
                count += 1

        ave_vloss = float(vloss)/count
        ave_gloss = float(gloss)/count
        return ave_vloss, ave_gloss 
Example 5
def train_1(self, model, trainloader, optimizer, device):
        model.train()
        vloss = 0.0
        pred  = 0.0
        count = 0
        for i, data in enumerate(trainloader):
            target, output, loss = self.compute_loss(model, data, device)
            # forward + backward + optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss1 = loss.item()
            vloss += loss1
            count += output.size(0)

            _, pred1 = output.max(dim=1)
            ag = (pred1 == target)
            am = ag.sum()
            pred += am.item()

        running_loss = float(vloss)/count
        accuracy = float(pred)/count
        return running_loss, accuracy 
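Note that count accumulates samples (output.size(0)) while vloss adds one value per batch, so running_loss is a true per-sample average only if compute_loss returns a summed rather than mean-reduced loss; accuracy, by contrast, is always correct predictions per sample.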
Example 6
def eval_1(self, model, testloader, device):
        model.eval()
        vloss = 0.0
        pred  = 0.0
        count = 0
        with torch.no_grad():
            for i, data in enumerate(testloader):
                target, output, loss = self.compute_loss(model, data, device)

                loss1 = loss.item()
                vloss += loss1
                count += output.size(0)

                _, pred1 = output.max(dim=1)
                ag = (pred1 == target)
                am = ag.sum()
                pred += am.item()

        ave_loss = float(vloss)/count
        accuracy = float(pred)/count
        return ave_loss, accuracy 
Example 7
def set_input(self, input):
        """ Set input and ground truth

        Args:
            input (tuple): a (data, ground-truth) pair of tensors for batch i.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

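The in-place resize_().copy_() calls reuse pre-allocated buffers (self.input, self.gt, self.label) instead of allocating fresh tensors for every batch; the first batch is additionally cached in self.fixed_input so that the same inputs can be regenerated for visualization throughout training.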
Example 8
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
):
    if aspect_grouping:
        if not isinstance(aspect_grouping, (list, tuple)):
            aspect_grouping = [aspect_grouping]
        aspect_ratios = _compute_aspect_ratios(dataset)
        group_ids = _quantize(aspect_ratios, aspect_grouping)
        batch_sampler = samplers.GroupedBatchSampler(
            sampler, group_ids, images_per_batch, drop_uneven=False
        )
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    if num_iters is not None:
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler 
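Whichever branch runs, the return value is a batch sampler: an iterable over lists of dataset indices. DataLoader consumes it through its batch_sampler argument (mutually exclusive with batch_size and shuffle). A minimal sketch of the plain, non-grouped path:

import torch

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size=4, drop_last=False)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
for (batch,) in loader:
    pass  # tensors [0..3], [4..7], then [8, 9]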
Example 9
def __init__(self, data, transform=lambda data: data, one_hot=None, shuffle=False, dir=None):
        """
        Load the cached data (.pkl) into memory.
        :author 申瑞珉 (Ruimin Shen)
        :param data: A list contains the data samples (dict).
        :param transform: A function that transforms the data dict (usually a sequence of data augmentation operations).
        :param one_hot: If an int value (total number of classes) is given, the class label (key "cls") will be generated in one-hot format.
        :param shuffle: Shuffle the loaded dataset.
        :param dir: The directory to store the exception data.
        """
        self.data = data
        if shuffle:
            random.shuffle(self.data)
        self.transform = transform
        self.one_hot = None if one_hot is None else sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32)
        self.dir = dir 
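Note: sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32) relies on the older scikit-learn API, where the first positional parameter was n_values (the total number of classes); that parameter was deprecated in 0.20 and removed in 0.22, so recent versions require a different construction.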
Example 10
def __init__(self, resize, sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None):
        """
        Unify multiple data samples (e.g., resize images to the same size, and pad bounding box labels to the same count) to form a batch.
        :author 申瑞珉 (Ruimin Shen)
        :param resize: A function to resize the image and labels.
        :param sizes: The image sizes to be randomly chosen.
        :param maintain: How many batches a chosen size is maintained before a new one is sampled.
        :param transform_image: A function to transform the resized image.
        :param transform_tensor: A function to standardize an image into a tensor.
        :param dir: The directory to store the exception data.
        """
        self.resize = resize
        self.sizes = sizes
        assert maintain > 0
        self.maintain = maintain
        self._maintain = maintain
        self.transform_image = transform_image
        self.transform_tensor = transform_tensor
        self.dir = dir 
Example 11
def __call__(self, batch):
        height, width = self.next_size()
        dim = max(len(data['cls']) for data in batch)
        _batch = []
        for data in batch:
            try:
                data = self.resize(data, height, width)
                data['image'] = self.transform_image(data['image'])
                data = padding_labels(data, dim)
                if self.transform_tensor is not None:
                    data['tensor'] = self.transform_tensor(data['image'])
                _batch.append(data)
            except Exception:
                if self.dir is not None:
                    os.makedirs(self.dir, exist_ok=True)
                    name = self.__module__ + '.' + type(self).__name__
                    with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
                        pickle.dump(data, f)
                raise
        return torch.utils.data.dataloader.default_collate(_batch) 
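The padding step matters because torch.utils.data.dataloader.default_collate stacks same-keyed entries across the batch into single tensors, which requires every sample's label arrays to share a shape; a sample that fails is pickled to self.dir for inspection before the exception is re-raised.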
Example 12
def get_loader(self):
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in self.config.get('eval', 'phase').split()]
        dataset = utils.data.Dataset(utils.data.load_pickles(paths))
        logging.info('num_examples=%d' % len(dataset))
        size = tuple(map(int, self.config.get('image', 'size').split()))
        try:
            workers = self.config.getint('data', 'workers')
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        collate_fn = utils.data.Collate(
            transform.parse_transform(self.config, self.config.get('transform', 'resize_eval')),
            [size],
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_test').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
        )
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers, collate_fn=collate_fn) 
Example 13
def __init__(self, args, config):
        self.args = args
        self.config = config
        self.model_dir = utils.get_model_dir(config)
        self.category = utils.get_category(config)
        self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
        self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), self.anchors, len(self.category))
        self.dnn.eval()
        logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values())))
        if torch.cuda.is_available():
            self.dnn.cuda()
        self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
        with torch.no_grad():  # the original used the legacy Variable(..., volatile=True)
            output = self.dnn(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)))
        _, _, self.rows, self.cols = output.size()
        self.i, self.j = self.rows // 2, self.cols // 2
        self.output = output[:, :, self.i, self.j]
        dataset = Dataset(self.height, self.width)
        try:
            workers = self.config.getint('data', 'workers')
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers) 
Example 14
def __call__(self):
        changed = np.zeros([self.height, self.width], dtype=bool)  # np.bool was removed in NumPy 1.24
        for yx in tqdm.tqdm(self.loader):
            batch_size = yx.size(0)
            tensor = torch.zeros(batch_size, 3, self.height, self.width)
            for i, _yx in enumerate(torch.unbind(yx)):
                y, x = torch.unbind(_yx)
                tensor[i, :, y, x] = 1
            tensor = utils.ensure_device(tensor)
            with torch.no_grad():  # the original used the legacy Variable(..., volatile=True)
                output = self.dnn(tensor)
            output = output[:, :, self.i, self.j]
            cmp = output == self.output
            cmp = torch.prod(cmp, -1).data
            for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
                y, x = torch.unbind(_yx)
                changed[y, x] = c
        return changed 
Example 15
def get_loader(self):
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in self.config.get('train', 'phase').split()]
        dataset = utils.data.Dataset(
            utils.data.load_pickles(paths),
            transform=transform.augmentation.get_transform(self.config, self.config.get('transform', 'augmentation').split()),
            one_hot=None if self.config.getboolean('train', 'cross_entropy') else len(self.category),
            shuffle=self.config.getboolean('data', 'shuffle'),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        logging.info('num_examples=%d' % len(dataset))
        try:
            workers = self.config.getint('data', 'workers')
            if torch.cuda.is_available():
                workers = workers * torch.cuda.device_count()
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        collate_fn = utils.data.Collate(
            transform.parse_transform(self.config, self.config.get('transform', 'resize_train')),
            utils.train.load_sizes(self.config),
            maintain=self.config.getint('data', 'maintain'),
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_train').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size * torch.cuda.device_count() if torch.cuda.is_available() else self.args.batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available()) 
Example 16
def iterate(self, data):
        for key in data:
            t = data[key]
            if torch.is_tensor(t):
                data[key] = utils.ensure_device(t)
        tensor = torch.autograd.Variable(data['tensor'])
        pred = pybenchmark.profile('inference')(model._inference)(self.inference, tensor)
        height, width = data['image'].size()[1:3]
        rows, cols = pred['feature'].size()[-2:]
        loss, debug = pybenchmark.profile('loss')(model.loss)(self.anchors, norm_data(data, height, width, rows, cols), pred, self.config.getfloat('model', 'threshold'))
        loss_hparam = {key: loss[key] * self.config.getfloat('hparam', key) for key in loss}
        loss_total = sum(loss_hparam.values())
        self.optimizer.zero_grad()
        loss_total.backward()
        try:
            clip = self.config.getfloat('train', 'clip')
            nn.utils.clip_grad_norm_(self.inference.parameters(), clip)  # renamed from clip_grad_norm in PyTorch 0.4
        except configparser.NoOptionError:
            pass
        self.optimizer.step()
        return dict(
            height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred, debug=debug,
            loss_total=loss_total, loss=loss, loss_hparam=loss_hparam,
        ) 
Example 17
def __getitem__(self, idx):
        '''
        :param idx: Index of the image file
        :return: the image and its corresponding label (plus an edge map when self.edge is set)
        '''
        image_name = self.imList[idx]
        label_name = self.labelList[idx]
        image = cv2.imread(image_name)
        label = cv2.imread(label_name, 0)
        label_bool = 255 * ((label > 200).astype(np.uint8))

        if self.transform:
            [image, label] = self.transform(image, label_bool)
        if self.edge:
            np_label = 255 * label.data.numpy().astype(np.uint8)
            kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
            erosion = cv2.erode(np_label, kernel, iterations=1)
            dilation = cv2.dilate(np_label, kernel, iterations=1)
            boundary = dilation - erosion
            edgemap = 255 * torch.ones_like(label)
            edgemap[torch.from_numpy(boundary) > 0] = label[torch.from_numpy(boundary) > 0]
            return (image, label, edgemap)
        else:
            return (image, label) 
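The boundary here is the morphological gradient (dilation minus erosion), which is nonzero only in a thin band around label edges; inside that band edgemap copies the label values, and everywhere else it keeps the sentinel value 255.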
Example 18
def __init__(self, root, num=10, split="train", download=False):
        assert num in {10, 30, 50}
        assert split in {"train", "test", "valid"}
        self.num = num
        self.split = split
        self.root = root
        if download:
            self.download()
        else:
            self._check_integrity()
        name = {"train": "train", "test": "test", "valid": "dev"}[split]
        with open(os.path.join(root, self._suffix, name), "rb") as f:
            self.data = pickle.load(f)
        with open(os.path.join(root, self._suffix, "id_to_word"), "rb") as f:
            self.id2word = pickle.load(f)
        with open(os.path.join(root, self._suffix, "word_to_id"), "rb") as f:
            self.word2id = pickle.load(f)
Example 19
def train(train_loader, model, loss_fn, optimizer, epoch, T, opts, start_time):

    loss_val = 0.
    num_iters = len(train_loader)
    num_prints = min([5, num_iters])
    model.train()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = Variable(inputs.cuda()) if opts.gpu else Variable(inputs)
        targets = targets.cuda() if opts.gpu else targets
        
        optimizer.zero_grad()
        if opts.method == 'TD':
            batch_loss = Variable(torch.zeros(1).cuda()) if opts.gpu else Variable(torch.zeros(1))
            for m, sub in enumerate(model):
                logits = sub(inputs)
                batch_loss += loss_fn(logits, targets, m)
        else:
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
        loss_val += batch_loss.item()  # .data[0] fails on 0-dim tensors in PyTorch >= 0.4
        batch_loss.backward()
        optimizer.step()
        
        print_me = (opts.batch_size > 0 and \
                    ((i+1) % (num_iters // num_prints) == 0 or i == 0 or i == num_iters-1)) \
                   or (opts.batch_size == 0 and (epoch == 1 or epoch % opts.save_freq == 0))
        if print_me:
            print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=opts.num_epochs), end='')
            print('{iter:3d}/{num_iters:3d} i; '.format(iter=i+1, num_iters=num_iters), end='')
            print('lr: {lr:.0e}; '.format(lr=optimizer.param_groups[0]['lr']), end='')
            print('bl: {loss:9.3f}; '.format(loss=batch_loss.item()), end='')
            print('ml: {loss:9.3f}; '.format(loss=float(loss_val)/(i+1)), end='')
            print('{time:8.3f} s'.format(time=time.time()-start_time))
    
    return loss_val 
Example 20
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0) 
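This is the standard DCGAN-style initializer (convolution weights ~ N(0, 0.02), batch-norm scales ~ N(1, 0.02), batch-norm biases zero). It is meant to be applied recursively with nn.Module.apply; a small usage sketch (the model here is illustrative):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net.apply(weights_init)  # visits every submodule, matching by class name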
Example 21
def get_datasets(args):

    cinfo = None
    if args.categoryfile:
        #categories = numpy.loadtxt(args.categoryfile, dtype=str, delimiter="\n").tolist()
        with open(args.categoryfile) as f:
            categories = [line.rstrip('\n') for line in f]
        categories.sort()
        c_to_idx = {categories[i]: i for i in range(len(categories))}
        cinfo = (categories, c_to_idx)

    perturbations = None
    fmt_trans = False
    if args.perturbations:
        perturbations = numpy.loadtxt(args.perturbations, delimiter=',')
    if args.format == 'wt':
        fmt_trans = True

    if args.dataset_type == 'modelnet':
        transform = torchvision.transforms.Compose([\
                ptlk.data.transforms.Mesh2Points(),\
                ptlk.data.transforms.OnUnitCube(),\
            ])

        testdata = ptlk.data.datasets.ModelNet(args.dataset_path, train=0, transform=transform, classinfo=cinfo)

        testset = ptlk.data.datasets.CADset4tracking_fixed_perturbation(testdata,\
                        perturbations, fmt_trans=fmt_trans)

    return testset 
Example 22
def compute_loss(self, model, data, device):
        p0, p1, igt = data
        p0 = p0.to(device) # template
        p1 = p1.to(device) # source
        igt = igt.to(device) # igt: p0 -> p1
        r = ptlk.pointlk.PointLK.do_forward(model, p0, p1, self.max_iter, self.xtol,\
                                            self.p0_zero_mean, self.p1_zero_mean)
        #r = model(p0, p1, self.max_iter)
        est_g = model.g

        loss_g = ptlk.pointlk.PointLK.comp(est_g, igt)

        if self._loss_type == 0:
            loss_r = ptlk.pointlk.PointLK.rsq(r)
            loss = loss_r
        elif self._loss_type == 1:
            loss_r = ptlk.pointlk.PointLK.rsq(r)
            loss = loss_r + loss_g
        elif self._loss_type == 2:
            pr = model.prev_r
            if pr is not None:
                loss_r = ptlk.pointlk.PointLK.rsq(r - pr)
            else:
                loss_r = ptlk.pointlk.PointLK.rsq(r)
            loss = loss_r + loss_g
        else:
            loss = loss_g

        return loss, loss_g 
Example 23
def options(argv=None):
    parser = argparse.ArgumentParser(description='ICP')

    # required.
    parser.add_argument('-o', '--outfile', required=True, type=str,
                        metavar='FILENAME', help='output filename (.csv)')
    parser.add_argument('-i', '--dataset-path', required=True, type=str,
                        metavar='PATH', help='path to the input dataset')
    parser.add_argument('-c', '--categoryfile', required=True, type=str,
                        metavar='PATH', help='path to the categories to be tested')
    parser.add_argument('-p', '--perturbations', required=True, type=str,
                        metavar='PATH', help='path to the perturbations')

    # settings for input data
    parser.add_argument('--dataset-type', default='modelnet', choices=['modelnet'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--format', default='wv', choices=['wv', 'wt'],
                        help='perturbation format (default: wv (twist)) (wt: rotation and translation)') # the output is always in twist format

    # settings for ICP
    parser.add_argument('--max-iter', default=20, type=int,
                        metavar='N', help='max-iter on ICP. (default: 20)')

    # settings for on testing
    parser.add_argument('-l', '--logfile', default='', type=str,
                        metavar='LOGNAME', help='path to logfile (default: null (no logging))')
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')

    args = parser.parse_args(argv)
    return args 
Example 24
def run(args, testset, action):
    LOGGER.debug('Testing (PID=%d), %s', os.getpid(), args)

    sys.setrecursionlimit(20000)

    # dataloader
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=1, shuffle=False, num_workers=args.workers)

    # testing
    LOGGER.debug('tests, begin')
    action.eval_1(testloader)
    LOGGER.debug('tests, end') 
Example 25
def compute_loss(self, model, data, device):
        points, target = data

        points = points.to(device)
        target = target.to(device)

        output = model(points)
        loss = model.loss(output, target)

        return target, output, loss 
Example 26
def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

Example 27
def train_one_epoch(self):
        """ Train the model for one epoch.
        """

        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            # self.optimize()
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))
        # self.visualizer.print_current_errors(self.epoch, errors)

Example 28
def batch_generator(data, batch_size, shuffle=True):
    """Yield elements from data in chunks of batch_size."""
    if shuffle:
        sampler = torch.utils.data.RandomSampler(data)
    else:
        sampler = torch.utils.data.SequentialSampler(data)
    minibatch = []
    for idx in sampler:
        minibatch.append(data[idx])
        if len(minibatch) == batch_size:
            yield minibatch
            minibatch = []
    if minibatch:
        yield minibatch
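A quick usage sketch; data can be any indexable sequence with a length, here a plain list:

data = list(range(10))
for batch in batch_generator(data, batch_size=4, shuffle=False):
    print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then the leftover [8, 9]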