delete DustyBlock, cosmetic changes
parent
8b845fdd76
commit
a7395fe27c
@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import AsyncIterator, Callable, Optional
|
||||
|
||||
from hivemind.moe.client import RemoteExpert
|
||||
from hivemind.moe.expert_uid import ExpertInfo
|
||||
from hivemind.p2p import P2P, StubBase
|
||||
from hivemind.proto import runtime_pb2
|
||||
from hivemind.utils import MSGPackSerializer, amap_in_executor
|
||||
|
||||
from src.client.spending_policy import SpendingPolicyBase
|
||||
|
||||
|
||||
# TODO: (greenfatguy) remove later, left for now as example
class DustyRemoteBlock(RemoteExpert):
    """RemoteExpert that stamps a "__dust" spending amount into the metadata of
    every outgoing rpc request, as priced by the given SpendingPolicyBase."""

    def __init__(self, bank: SpendingPolicyBase, expert_info: ExpertInfo, p2p: P2P):
        # Keep the policy around before delegating the rest to RemoteExpert.
        self._spending_policy = bank
        super().__init__(expert_info, p2p)

    def _unary_request_wrapper(self, rpc_call: Callable, rpc_name: str):
        """Return a wrapper around a unary rpc method that adds "__dust" to the
        request metadata before forwarding the call."""

        @wraps(rpc_call)
        async def rpc(input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
            payload = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
            payload["__dust"] = self._spending_policy.get_points(input, rpc_name)
            input.metadata = MSGPackSerializer.dumps(payload)
            return await rpc_call(input, timeout)

        return rpc

    def _stream_request_wrapper(self, rpc_call: Callable, rpc_name: str):
        """Return a wrapper around a streaming rpc method that adds "__dust" to
        the metadata of the first chunk only; later chunks pass through as-is."""

        @wraps(rpc_call)
        async def rpc(input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
            is_meta_set = False

            def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
                # Stamp the dust amount onto the first chunk, then become a no-op.
                nonlocal is_meta_set
                if is_meta_set:
                    return chunk
                payload = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
                payload["__dust"] = self._spending_policy.get_points(chunk, rpc_name)
                chunk.metadata = MSGPackSerializer.dumps(payload)
                is_meta_set = True
                return chunk

            return await rpc_call(amap_in_executor(_metadata_setter, input), timeout)

        return rpc

    def _prioritize_handler_stub_calls(self, stub: StubBase) -> StubBase:
        """Patch every rpc* method on the stub in place with a dust-stamping
        wrapper chosen by the method's input annotation; return the same stub."""
        for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
            if not name.startswith("rpc"):
                continue
            spec = inspect.getfullargspec(method)
            # rpc callers has 3 arguments: stub, input and timeout
            if len(spec.args) != 3:
                continue
            input_type = spec.annotations[spec.args[1]]
            if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
                setattr(stub, name, self._stream_request_wrapper(method, name))
            elif input_type is runtime_pb2.ExpertRequest:
                setattr(stub, name, self._unary_request_wrapper(method, name))
        return stub

    @property
    def _stub(self) -> StubBase:
        # The raw stub from RemoteExpert; tests override this property.
        return super().stub

    @property
    def stub(self) -> StubBase:
        # Public stub with the dust-stamping wrappers applied.
        return self._prioritize_handler_stub_calls(self._stub)
|
@ -1,99 +0,0 @@
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
|
||||
from hivemind.proto.runtime_pb2 import ExpertRequest
|
||||
from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
|
||||
|
||||
from src.client.priority_block import DustyRemoteBlock
|
||||
from src.client.spending_policy import SpendingPolicyBase
|
||||
|
||||
|
||||
class SpendingPolicyTest(SpendingPolicyBase):
    """Spending policy with a fixed per-method price table; any method not in
    the table costs -1."""

    def __init__(self):
        # Prices keyed by rpc method name.
        self._p = {"rpc_single": 1, "rpc_stream": 2}

    def get_points(self, request: ExpertRequest, method_name: str) -> float:
        # The request itself is ignored; only the method name sets the price.
        return self._p.get(method_name, -1)
|
||||
|
||||
|
||||
class HandlerStubTest:
    """Minimal stand-in for a p2p stub: every rpc handler simply echoes its
    input. The parameter annotations matter — DustyRemoteBlock inspects them
    to decide which wrapper (unary/stream/none) to apply."""

    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
        # Unary handler: echoes the single request back.
        return input

    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
        # Streaming handler: echoes the request iterator back.
        return input

    async def rpc_info(self, input: str, timeout: Optional[float] = None):
        # Non-ExpertRequest input: should be left unwrapped (see test_no_wrapper).
        return input
|
||||
|
||||
|
||||
class RemoteBlockTest(DustyRemoteBlock):
    """DustyRemoteBlock whose underlying stub is a local HandlerStubTest, so no
    real p2p connection is needed."""

    @property
    def _stub(self):
        # Fresh echo stub per access; replaces RemoteExpert's network stub.
        return HandlerStubTest()
|
||||
|
||||
|
||||
@pytest.mark.forked
@pytest.mark.asyncio
async def test_single():
    """A unary rpc call must come back with "__dust" == 1 in its metadata and
    an unchanged tensor payload."""
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    input = torch.randn(1, 2)
    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])

    # Fix: removed leftover debug `print(stub)`.
    out: ExpertRequest = await stub.rpc_single(request)

    # The wrapper must have set metadata and left the tensor intact.
    assert out.metadata != b""
    assert len(out.tensors) == 1
    assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0]))

    meta = MSGPackSerializer.loads(out.metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 1  # SpendingPolicyTest prices rpc_single at 1
|
||||
|
||||
|
||||
@pytest.mark.forked
@pytest.mark.asyncio
async def test_stream():
    """In a streaming rpc, only the first chunk carries "__dust" metadata
    (priced 2 for rpc_stream); the reassembled tensor must equal the input."""
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    input = torch.randn(2**21, 2)

    # Serialize the tensor and cut it into 64 KiB chunks for streaming.
    chunks = (
        part
        for serialized in [serialize_torch_tensor(input)]
        for part in split_for_streaming(serialized, chunk_size_bytes=2**16)
    )
    output_generator = await stub.rpc_stream(
        amap_in_executor(
            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(chunks),
        ),
    )
    received = [part async for part in output_generator]

    assert len(received) == 2**5 * 8
    # Only the head chunk is stamped; every later chunk has empty metadata.
    assert received[0].metadata != b""
    for later in received[1:]:
        assert later.metadata == b""

    meta = MSGPackSerializer.loads(received[0].metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 2

    # Reassemble the streamed chunks and compare with the original tensor.
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(received)))
    assert len(results) == 1
    assert torch.allclose(results[0], input)
|
||||
|
||||
|
||||
@pytest.mark.forked
@pytest.mark.asyncio
async def test_no_wrapper():
    """rpc methods whose input is not an ExpertRequest (rpc_info takes a str)
    must be left unwrapped and echo their input untouched."""
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub

    echoed = await stub.rpc_info("Test")
    assert echoed == "Test"
|
Loading…
Reference in New Issue