|
|
|
@ -164,6 +164,8 @@ def fix_torch_group_norm():
|
|
|
|
|
|
|
|
|
|
def randn_seeded(seed: int, size: List[int]) -> Tensor:
|
|
|
|
|
"""Generate a random tensor with a given seed."""
|
|
|
|
|
from hashlib import md5
|
|
|
|
|
|
|
|
|
|
g_cpu = torch.Generator()
|
|
|
|
|
g_cpu.manual_seed(seed)
|
|
|
|
|
noise = torch.randn(
|
|
|
|
@ -171,6 +173,9 @@ def randn_seeded(seed: int, size: List[int]) -> Tensor:
|
|
|
|
|
device="cpu",
|
|
|
|
|
generator=g_cpu,
|
|
|
|
|
)
|
|
|
|
|
# md5 of the torch tensor `noise`
|
|
|
|
|
torch_md5 = md5(noise.numpy().tobytes()).hexdigest()
|
|
|
|
|
logger.debug(f"Made noise of size {size} from seed {seed}. md5:{torch_md5}")
|
|
|
|
|
return noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|