Python源码示例:torch.get_rng_state()

示例1
def __getitem__(self, index):
        """
        Return the synthetic sample at ``index``.

        The image and target are drawn right after seeding the torch CPU RNG
        with ``index + self.random_offset``, so the same index always yields
        the same sample. The caller's torch CPU RNG state
        (``torch.get_rng_state()``) is saved before seeding and restored
        afterwards -- also when generation raises -- so indexing the dataset
        does not perturb the caller's random stream.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
            ``target`` is element 0 of a default-type ``torch.Tensor`` filled
            via ``random_(0, self.num_classes)``, i.e. a 0-dim tensor holding
            an integral value, unless ``self.target_transform`` changes it.
        """
        # Create a random image that is consistent with the index id, without
        # leaking the reseeded state into the caller's RNG stream.
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(index + self.random_offset)
            img = torch.randn(*self.image_size)
            target = torch.Tensor(1).random_(0, self.num_classes)[0]
        finally:
            # Restore even on failure; previously an exception here left the
            # global generator reseeded.
            torch.set_rng_state(rng_state)

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
示例2
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例3
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例4
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例5
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例6
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例7
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例8
def save_rng_states(device: torch.device,
                    rng_states: Deque[RNGStates],
                    ) -> None:
    """Snapshot the PyTorch RNG states and push them onto ``rng_states``.

    The CPU generator state is always captured. The GPU generator state of
    ``device`` is captured only when ``device.type == 'cuda'``; otherwise it
    is ``None``. The pair ``(cpu_state, gpu_state)`` is appended to the right
    end of ``rng_states``. :meth:`Checkpoint.forward` calls this so that
    :meth:`Recompute.backward` can reuse the same states.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_state = torch.get_rng_state()
    on_gpu = device.type == 'cuda'
    gpu_state: Optional[ByteTensor] = (
        torch.cuda.get_rng_state(device) if on_gpu else None
    )
    rng_states.append((cpu_state, gpu_state))
示例9
def test_get_set_device_states(device, enabled):
    """Replaying saved RNG state inside ``fork_rng`` reproduces the same draw.

    ``get_device_states`` must report one device/state for a CUDA input and
    none for a CPU input. After drawing ``Y``, the draw is repeated as ``Y2``
    inside ``torch.random.fork_rng``. When ``enabled`` is true, the saved
    state is restored first, so ``Y`` and ``Y2`` are equal exactly when
    ``enabled`` is true.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('This test requires a GPU to be available')
    shape = (1, 1, 10, 10)
    X = torch.ones(shape, device=device)
    devices, states = get_device_states(X)
    expected_count = 1 if device == 'cuda' else 0
    assert len(states) == expected_count
    assert len(devices) == expected_count
    cpu_rng_state = torch.get_rng_state()
    Y = X * torch.rand(shape, device=device)
    with torch.random.fork_rng(devices=devices, enabled=True):
        if enabled and device == 'cpu':
            torch.set_rng_state(cpu_rng_state)
        elif enabled:
            set_device_states(devices=devices, states=states)
        Y2 = X * torch.rand(shape, device=device)
    assert torch.equal(Y, Y2) == enabled
示例10
def setUp(self):
        """Build random SPD matrices and right-hand sides for the tests.

        Unless the ``unlock_seed`` environment variable is set to something
        other than "false" (case-insensitive), the torch, CUDA (if available)
        and Python ``random`` generators are seeded with 1 first. In that case
        the prior torch CPU RNG state is kept in ``self.rng_state``.

        Creates ``self.mats``/``self.mats_clone`` (5 x 4 x 4) and
        ``self.vecs``/``self.vecs_clone`` (5 x 4 x 6). They are independent
        copies detached from any graph and marked ``requires_grad``.
        """
        unlock = os.getenv("unlock_seed")
        if unlock is None or unlock.lower() == "false":
            self.rng_state = torch.get_rng_state()
            torch.manual_seed(1)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(1)
            random.seed(1)

        def fresh_leaf(src):
            # Independent copy, cut from src's graph, that tracks gradients.
            return src.detach().clone().requires_grad_(True)

        base = torch.randn(5, 4, 4)
        # A @ A^T / 5 + I is symmetric positive definite.
        spd = base @ base.transpose(-1, -2)
        spd.div_(5).add_(torch.eye(4).unsqueeze(0))
        rhs = torch.randn(5, 4, 6)
        self.mats, self.mats_clone = fresh_leaf(spd), fresh_leaf(spd)
        self.vecs, self.vecs_clone = fresh_leaf(rhs), fresh_leaf(rhs)
示例11
def setUp(self):
        """Build batched random SPD matrices and right-hand sides for the tests.

        Unless the ``unlock_seed`` environment variable is set to something
        other than "false" (case-insensitive), the torch, CUDA (if available)
        and Python ``random`` generators are seeded with 0 first. In that case
        the prior torch CPU RNG state is kept in ``self.rng_state``.

        Creates ``self.mats``/``self.mats_clone`` (2 x 3 x 4 x 4) and
        ``self.vecs``/``self.vecs_clone`` (2 x 3 x 4 x 6). They are
        independent copies detached from any graph and marked
        ``requires_grad``.
        """
        unlock = os.getenv("unlock_seed")
        if unlock is None or unlock.lower() == "false":
            self.rng_state = torch.get_rng_state()
            torch.manual_seed(0)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(0)
            random.seed(0)

        def fresh_leaf(src):
            # Independent copy, cut from src's graph, that tracks gradients.
            return src.detach().clone().requires_grad_(True)

        base = torch.randn(2, 3, 4, 4)
        # A @ A^T / 5 + I is symmetric positive definite (identity broadcast
        # over both batch dimensions).
        spd = base @ base.transpose(-1, -2)
        spd.div_(5).add_(torch.eye(4).reshape(1, 1, 4, 4))
        rhs = torch.randn(2, 3, 4, 6)
        self.mats, self.mats_clone = fresh_leaf(spd), fresh_leaf(spd)
        self.vecs, self.vecs_clone = fresh_leaf(rhs), fresh_leaf(rhs)
示例12
def save_states(self):
        """Save the states inside a checkpoint associated with ``epoch``.

        Collects the model and optimizer ``state_dict``s and the Python,
        NumPy, torch CPU and (if available) CUDA RNG states. It adds the
        counters and loss histories, merges ``self.save_states_others()`` on
        top, and hands the dict to ``self.experiment.checkpoint_save`` with
        ``self.counters['epoch']``.
        """
        net = self.model
        if isinstance(net, torch.nn.DataParallel):
            # Save the wrapped module rather than the DataParallel wrapper.
            net = net.module
        cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        checkpoint_data = {
            'model': net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'random_states': (
                random.getstate(),
                np.random.get_state(),
                torch.get_rng_state(),
                cuda_state,
            ),
            'counters': self.counters,
            'losses_epoch': self.losses_epoch,
            'losses_it': self.losses_it,
        }
        checkpoint_data.update(self.save_states_others())
        self.experiment.checkpoint_save(checkpoint_data, self.counters['epoch'])
示例13
def get_checkpoint(S, stop_conds, rng=None, get_state=True):
    """
    Save the necessary information about ``S`` into a dictionary.

    Args:
        S: object exposing ``ninitfeats``, ``x0`` and a tensor ``x``. When
            ``get_state`` is true it must also provide ``state_dict()`` and
            ``opt_train.state_dict()``.
        stop_conds: mapping whose values are indexable; ``value[0]`` is
            stored under each key.
        rng: optional object with a ``get_state()`` method (presumably a
            NumPy ``RandomState`` -- confirm with callers). Its state is
            stored under ``'rng_state'`` whenever ``rng`` is not None.
        get_state: when true, also store the model/optimizer state dicts and
            the torch CPU RNG state under the ``constants.Checkpoint`` keys.

    Returns:
        dict: the checkpoint contents. ``'feats'`` holds the indices (along
        the first axis) where ``S.x >= 0``.
    """
    m = {}
    m['ninitfeats'] = S.ninitfeats
    m['x0'] = S.x0
    # Detach from autograd and move to CPU for NumPy. No clone is needed
    # because only indices derived from x are stored, never x itself.
    x = S.x.detach().cpu()
    m['feats'] = np.where(x.numpy() >= 0)[0]
    m.update({k: v[0] for k, v in stop_conds.items()})
    if get_state:
        m.update({constants.Checkpoint.MODEL: S.state_dict(),
                  constants.Checkpoint.OPT: S.opt_train.state_dict(),
                  constants.Checkpoint.RNG: torch.get_rng_state(),
                  })
    # Compare with None explicitly: a falsy-but-valid rng object must still
    # have its state recorded.
    if rng is not None:
        m['rng_state'] = rng.get_state()

    return m
示例14
def save_model(self, epochs=-1, optimisers=None, save_dir=None, name=ALICE, timestamp=None):
        '''
        Persist the model, its optimisers and the global RNG states to disk.

        Args:
            epochs: last completed epoch; ``epochs + 1`` is stored under ``EPOCHS``.
            optimisers: iterable of optimisers whose ``state_dict()`` is saved.
                ``None`` (the default) is treated as "no optimisers".
            save_dir: directory the ``.tar`` file is written to. It must be
                given, because ``os.path.join`` cannot join a ``None`` directory.
            name: file-name prefix.
            timestamp: file-name suffix. When falsy, ``str(int(time()))`` is
                used (presumably the current Unix time in seconds).
        '''
        if not timestamp:
            timestamp = str(int(time()))
        if optimisers is None:
            # The old default crashed while iterating None; an empty list keeps
            # the default usable and stores an empty OPTIMISER list.
            optimisers = []
        state = {
            EPOCHS: epochs + 1,
            STATE_DICT: self.state_dict(),
            OPTIMISER: [optimiser.state_dict() for optimiser in optimisers],
            NP_RANDOM_STATE: np.random.get_state(),
            PYTHON_RANDOM_STATE: random.getstate(),
            PYTORCH_RANDOM_STATE: torch.get_rng_state()
        }
        path = os.path.join(save_dir,
                            name + "_model_timestamp_" + timestamp + ".tar")
        torch.save(state, path)
        print("saved model to path = {}".format(path))
示例15
def save_model(self, epochs=-1, optimisers=None, save_dir=None, name=ALICE, timestamp=None):
        '''
        Persist the model, its optimisers and the global RNG states to disk.

        Args:
            epochs: last completed epoch; ``epochs + 1`` is stored under ``EPOCHS``.
            optimisers: iterable of optimisers whose ``state_dict()`` is saved.
                ``None`` (the default) is treated as "no optimisers".
            save_dir: directory the ``.tar`` file is written to. It must be
                given, because ``os.path.join`` cannot join a ``None`` directory.
            name: file-name prefix.
            timestamp: file-name suffix. When falsy, ``str(int(time()))`` is
                used (presumably the current Unix time in seconds).
        '''
        if not timestamp:
            timestamp = str(int(time()))
        if optimisers is None:
            # The old default crashed while iterating None; an empty list keeps
            # the default usable and stores an empty OPTIMISER list.
            optimisers = []
        state = {
            EPOCHS: epochs + 1,
            STATE_DICT: self.state_dict(),
            OPTIMISER: [optimiser.state_dict() for optimiser in optimisers],
            NP_RANDOM_STATE: np.random.get_state(),
            PYTHON_RANDOM_STATE: random.getstate(),
            PYTORCH_RANDOM_STATE: torch.get_rng_state()
        }
        path = os.path.join(save_dir,
                            name + "_model_timestamp_" + timestamp + ".tar")
        torch.save(state, path)
        print("saved model to path = {}".format(path))
示例16
def with_torch_seed(seed):
    """Generator that seeds torch with ``seed`` and restores the prior RNG states.

    The torch CPU RNG state, and the current CUDA device's RNG state when CUDA
    is available, are saved before ``set_torch_seed(seed)`` is called. They
    are restored after the single ``yield``, also when an exception is thrown
    into the generator at that point. Presumably wrapped with
    ``contextlib.contextmanager`` at the definition site -- confirm.
    """
    assert isinstance(seed, int)
    rng_state = torch.get_rng_state()
    # Touching torch.cuda unconditionally fails on CPU-only setups; None marks
    # "no CUDA state captured".
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    set_torch_seed(seed)
    try:
        yield
    finally:
        torch.set_rng_state(rng_state)
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)
示例17
def torch_seed(seed: Optional[int]):
    """Temporarily seed the torch CPU generator, restoring its state afterwards.

    With an integer ``seed``, the current ``torch.get_rng_state()`` is saved,
    ``torch.manual_seed(seed)`` is applied, and the saved state is restored
    after the single ``yield`` (even if an exception is thrown in). With
    ``seed=None`` the generator just yields once and touches nothing, which
    is the same as running without the context manager.
    """
    if seed is not None:
        saved = torch.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            torch.set_rng_state(saved)
    else:
        yield
示例18
def write_dummy_file(filename, num_examples, maxlen):
    """Write ``num_examples`` lines of random lowercase letters to ``filename``.

    Each line holds between 1 and ``maxlen`` single letters separated by
    spaces. The letters come from the torch CPU RNG seeded with 0, so they
    are reproducible. The caller's torch RNG state is saved and restored,
    also on error. Line lengths come from Python's ``random`` module, which
    is not reseeded here.
    """
    rng_state = torch.get_rng_state()
    try:
        torch.manual_seed(0)
        data = torch.rand(num_examples * maxlen)
    finally:
        # Restore right after the only torch draw so that no later failure
        # (e.g. while writing the file) can leak the reseeded state.
        torch.set_rng_state(rng_state)
    # Map floats in [0, 1) to ASCII codes 97..122 ('a'..'z').
    codes = (97 + torch.floor(26 * data).int()).tolist()
    with open(filename, "w", encoding="utf-8") as h:
        offset = 0
        for _ in range(num_examples):
            ex_len = random.randint(1, maxlen)
            ex_str = " ".join(map(chr, codes[offset : offset + ex_len]))
            print(ex_str, file=h)
            offset += ex_len
示例19
def __enter__(self):
        """Remember the torch CPU RNG state, then reseed with ``self.rng_seed``.

        The snapshot is stored in ``self.old_rng_state``, presumably for a
        matching ``__exit__`` (not shown) to restore. Returns None, so
        ``with obj as x`` binds ``x`` to None.
        """
        snapshot = torch.get_rng_state()
        self.old_rng_state = snapshot
        torch.manual_seed(self.rng_seed)
示例20
def checkpoint(acc, epoch):
    """Save the global ``net`` with ``acc``, ``epoch`` and the torch CPU RNG state.

    The module object itself (not its ``state_dict``) is stored. The file is
    ``./checkpoint/ckpt.t7.<args.sess>_<args.seed>``. ``net`` and ``args``
    are module-level globals defined outside this function.
    """
    # Save checkpoint.
    print('Saving..')
    state = {
        'net': net,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/ckpt.t7.' + args.sess + '_' + str(args.seed))
示例21
def checkpoint(acc, epoch):
    """Save the global ``net`` with ``acc``, ``epoch`` and the torch CPU RNG state.

    The module object itself (not its ``state_dict``) is stored. The file is
    ``./checkpoint/<args.arch>_<args.sess>_<args.seed>.ckpt``. ``net`` and
    ``args`` are module-level globals defined outside this function.
    """
    # Save checkpoint.
    print('Saving..')
    state = {
        'net': net,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/' + args.arch + '_' + args.sess + '_' + str(args.seed) + '.ckpt')
示例22
def get_device_states(*args):
    """Return the distinct CUDA devices among ``args`` and their current RNG states.

    Only ``torch.Tensor`` arguments with ``is_cuda`` set contribute. Other
    arguments (CPU tensors, non-tensors) are skipped without error because
    the ``isinstance`` test short-circuits. Returns ``(devices, states)``,
    where ``states[i]`` is ``torch.cuda.get_rng_state()`` read while
    ``devices[i]`` is the current device.
    """
    devices = list({arg.get_device() for arg in args
                    if isinstance(arg, torch.Tensor) and arg.is_cuda})

    def _state_on(dev):
        # Read the RNG state with `dev` temporarily made the current device.
        with torch.cuda.device(dev):
            return torch.cuda.get_rng_state()

    return devices, [_state_on(dev) for dev in devices]
示例23
def save_rng_state(file_name):
    """Write the current torch CPU RNG state to ``file_name``.

    Serialization is delegated to ``torch_save``, a helper defined outside
    this view.
    """
    torch_save(torch.get_rng_state(), file_name)
示例24
def getstate(self) -> np.ndarray:
        """Return the torch CPU RNG state as a NumPy array (via ``Tensor.numpy()``)."""
        state_tensor = torch.get_rng_state()
        return state_tensor.numpy()
示例25
def freeze_rng_state():
    """Generator that snapshots the torch RNG state(s) and restores them on exit.

    The CPU RNG state, and the current CUDA device's RNG state when CUDA is
    available, are saved before the single ``yield``. Both are restored
    afterwards, including when an exception is thrown into the generator at
    the ``yield``. Presumably wrapped with ``contextlib.contextmanager`` at
    the definition site -- confirm.
    """
    rng_state = torch.get_rng_state()
    # None means "no CUDA state captured"; this avoids an unbound local on exit.
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
示例26
def save_checkpoint(acc, epoch):
    """Save the global ``model`` with ``acc``, ``epoch`` and the torch CPU RNG state.

    The module object itself (not its ``state_dict``) is stored. The file is
    ``args.save_dir + args.name + '_epoch<epoch>.ckpt'``. ``model`` and
    ``args`` are module-level globals defined outside this function.
    """
    print('=====> Saving checkpoint...')
    state = {
        'model': model,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    path = args.save_dir + args.name + '_epoch' + str(epoch) + '.ckpt'
    # Kept for backward compatibility: this directory was always created.
    os.makedirs('checkpoint', exist_ok=True)
    # The file actually lives under args.save_dir, which was never created
    # before, so opening it failed when that directory was missing.
    target_dir = os.path.dirname(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save(state, path)
示例27
def save_checkpoint(acc, epoch):
    """Save the global ``model`` with ``acc``, ``epoch`` and the torch CPU RNG state.

    The module object itself (not its ``state_dict``) is stored. The file is
    ``args.save_dir + args.name + '_epoch<epoch>.ckpt'``. ``model`` and
    ``args`` are module-level globals defined outside this function.
    """
    print('=====> Saving checkpoint...')
    state = {
        'model': model,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    path = args.save_dir + args.name + '_epoch' + str(epoch) + '.ckpt'
    # Kept for backward compatibility: this directory was always created.
    os.makedirs('checkpoint', exist_ok=True)
    # The file actually lives under args.save_dir, which was never created
    # before, so opening it failed when that directory was missing.
    target_dir = os.path.dirname(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save(state, path)


# Decrease the learning rate at epochs 100 and 150
示例28
def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` under ``torch.no_grad()`` and stash inputs on ``ctx``.

        ``args`` is validated with ``check_backward_validity``. ``run_function``
        is stored on ``ctx`` and ``args`` go through ``ctx.save_for_backward``.
        When the free name ``preserve_rng_state`` (defined outside this view)
        is truthy, the CPU RNG state is stashed. The current CUDA device's RNG
        state is also stashed if CUDA is already initialized. Presumably this
        lets a later recomputation replay the same random draws -- confirm
        against the backward pass.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        if preserve_rng_state:
            # The run function might move args between host and device, so the
            # CPU state is stashed unconditionally.
            #
            # TODO: args moved to a device other than the current one are not
            # covered. Stashing the states of ALL visible devices would handle
            # that, but is very wasteful in most cases.
            ctx.fwd_cpu_rng_state = torch.get_rng_state()
            # Only read the CUDA state if CUDA is already initialized, so the
            # CUDA context is never created here by accident. If run_function
            # initializes CUDA itself later, that state is not captured; this
            # cannot be anticipated before running it.
            cuda_ready = bool(torch.cuda._initialized)
            ctx.had_cuda_in_fwd = cuda_ready
            if cuda_ready:
                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)
示例29
def setUp(self):
        """Pin the torch, CUDA (if available) and Python RNGs to seed 0.

        Seeding is skipped when the ``UNLOCK_SEED`` environment variable is
        set to anything other than "false" (case-insensitive). When seeding
        happens, the previous torch CPU RNG state is saved in
        ``self.rng_state``.
        """
        unlock = os.getenv("UNLOCK_SEED")
        seeds_locked = unlock is None or unlock.lower() == "false"
        if not seeds_locked:
            return
        self.rng_state = torch.get_rng_state()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        random.seed(0)
示例30
def setUp(self):
        """Pin the torch, CUDA (if available) and Python RNGs to seed 0.

        Seeding is skipped when the ``UNLOCK_SEED`` environment variable is
        set to anything other than "false" (case-insensitive). When seeding
        happens, the previous torch CPU RNG state is saved in
        ``self.rng_state``.
        """
        unlock = os.getenv("UNLOCK_SEED")
        seeds_locked = unlock is None or unlock.lower() == "false"
        if not seeds_locked:
            return
        self.rng_state = torch.get_rng_state()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        random.seed(0)