|
|
|
@ -57,14 +57,15 @@ dht = hivemind.DHT(
|
|
|
|
|
client_mode=True, start=True,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
m, = get_remote_module(dht, ['bloom6b3.0'])
|
|
|
|
|
layer0, layer1 = get_remote_module(dht, ['bloom6b3.0', 'bloom6b3.1'])
|
|
|
|
|
|
|
|
|
|
# test forward/backward, one block
|
|
|
|
|
outputs = m(torch.randn(1, 128, 4096))
|
|
|
|
|
# test forward/backward, two blocks
|
|
|
|
|
outputs, = layer1(*layer0(torch.randn(1, 64, 4096)))
|
|
|
|
|
loss = (outputs * torch.randn_like(outputs)).norm()
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
|
|
with m.begin_inference_session() as sess:
|
|
|
|
|
# test inference, one block
|
|
|
|
|
with layer0.begin_inference_session() as sess:
|
|
|
|
|
for i in range(10):
|
|
|
|
|
res = sess.step(torch.ones(1, 1, 4096))
|
|
|
|
|
```
|
|
|
|
|