|
|
|
@ -74,7 +74,7 @@ class BatchedBrownianTree:
|
|
|
|
|
t0, t1, self.sign = self.sort(t0, t1)
|
|
|
|
|
w0 = kwargs.get("w0", torch.zeros_like(x))
|
|
|
|
|
if seed is None:
|
|
|
|
|
seed = torch.randint(0, 2**63 - 1, []).item()
|
|
|
|
|
seed = torch.randint(0, 2**63 - 1, [], device="cpu").item()
|
|
|
|
|
self.batched = True
|
|
|
|
|
try:
|
|
|
|
|
assert len(seed) == x.shape[0]
|
|
|
|
@ -412,7 +412,7 @@ def log_likelihood(
|
|
|
|
|
):
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
v = torch.randint_like(x, 2) * 2 - 1
|
|
|
|
|
v = torch.randint_like(x, 2, device="cpu") * 2 - 1
|
|
|
|
|
fevals = 0
|
|
|
|
|
|
|
|
|
|
def ode_fn(sigma, x):
|
|
|
|
|