instructions to test distributed inference

8bit_blocks
justheuristic 2 years ago
parent 9be7c81b78
commit 2d55e6e4fe

@ -57,14 +57,15 @@ dht = hivemind.DHT(
client_mode=True, start=True,
)
m, = get_remote_module(dht, ['bloom6b3.0'])
layer0, layer1 = get_remote_module(dht, ['bloom6b3.0', 'bloom6b3.1'])
# test forward/backward, one block
outputs = m(torch.randn(1, 128, 4096))
# test forward/backward, two blocks
outputs, = layer1(*layer0(torch.randn(1, 64, 4096)))
loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()
with m.begin_inference_session() as sess:
# test inference, one block
with layer0.begin_inference_session() as sess:
for i in range(10):
res = sess.step(torch.ones(1, 1, 4096))
```

Loading…
Cancel
Save