feature: remove more randomness

pull/154/head
Bryce 1 year ago committed by Bryce Drennan
parent eb40842078
commit 4bc940ddf4

@ -132,6 +132,8 @@ vendorize_kdiffusion:
sed -i '' -e 's#return (x - denoised) / sigma#return ((x - denoised) / sigma.to("cpu")).to(x.device)#g' imaginairy/vendored/k_diffusion/sampling.py
sed -i '' -e 's#return t.neg().exp()#return t.to("cpu").neg().exp().to(self.model.device)#g' imaginairy/vendored/k_diffusion/sampling.py
sed -i '' -e 's#import torchsde##g' imaginairy/vendored/k_diffusion/sampling.py
sed -i '' -e 's#torch.randint(0, 2\*\*63 - 1, \[\])#torch.randint(0, 2**63 - 1, [], device="cpu")#g' imaginairy/vendored/k_diffusion/sampling.py
sed -i '' -e 's#torch.randint_like(x, 2)#torch.randint_like(x, 2, device="cpu")#g' imaginairy/vendored/k_diffusion/sampling.py
make af
vendorize_noodle_soup:

@ -74,7 +74,7 @@ class BatchedBrownianTree:
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
seed = torch.randint(0, 2**63 - 1, [], device="cpu").item()
self.batched = True
try:
assert len(seed) == x.shape[0]
@ -412,7 +412,7 @@ def log_likelihood(
):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
v = torch.randint_like(x, 2, device="cpu") * 2 - 1
fevals = 0
def ode_fn(sigma, x):

Loading…
Cancel
Save