delete DustyBlock, cosmetic changes

pull/47/head
Pavel Samygin 2 years ago
parent 8b845fdd76
commit a7395fe27c

@ -1,5 +1,4 @@
from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
from src.client.priority_block import DustyRemoteBlock
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager

@ -1,72 +0,0 @@
from __future__ import annotations
import inspect
from functools import wraps
from typing import AsyncIterator, Callable, Optional
from hivemind.moe.client import RemoteExpert
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, StubBase
from hivemind.proto import runtime_pb2
from hivemind.utils import MSGPackSerializer, amap_in_executor
from src.client.spending_policy import SpendingPolicyBase
# TODO: (greenfatguy) remove later, left for now as example
class DustyRemoteBlock(RemoteExpert):
    """RemoteExpert that transparently attaches a "__dust" spending amount to the
    metadata of every outgoing rpc* call, priced by the given SpendingPolicyBase.
    """

    def __init__(self, bank: SpendingPolicyBase, expert_info: ExpertInfo, p2p: P2P):
        # Policy consulted once per unary call / once per stream to price the request.
        self._spending_policy = bank
        super().__init__(expert_info, p2p)

    def _unary_request_wrapper(self, rpc_call: Callable, rpc_name: str):
        """Wrap a unary rpc coroutine so the request metadata carries the "__dust" points."""
        @wraps(rpc_call)
        async def rpc(input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
            # Merge "__dust" into whatever metadata is already attached to the request.
            meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
            meta["__dust"] = self._spending_policy.get_points(input, rpc_name)
            input.metadata = MSGPackSerializer.dumps(meta)
            return await rpc_call(input, timeout)
        return rpc

    def _stream_request_wrapper(self, rpc_call: Callable, rpc_name: str):
        """Wrap a streaming rpc coroutine; only the FIRST chunk of the stream is tagged
        with "__dust" metadata — subsequent chunks pass through untouched."""
        @wraps(rpc_call)
        async def rpc(input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
            is_meta_set = False
            def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
                # Runs per chunk; flips is_meta_set after tagging the first one.
                nonlocal is_meta_set
                if not is_meta_set:
                    meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
                    meta["__dust"] = self._spending_policy.get_points(chunk, rpc_name)
                    chunk.metadata = MSGPackSerializer.dumps(meta)
                    is_meta_set = True
                return chunk
            return await rpc_call(amap_in_executor(_metadata_setter, input), timeout)
        return rpc

    def _prioritize_handler_stub_calls(self, stub: StubBase) -> StubBase:
        """Replace every rpc* method on the stub with the matching dust-attaching
        wrapper, chosen by the method's declared input annotation (unary vs stream)."""
        for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
            if name.startswith("rpc"):
                spec = inspect.getfullargspec(method)
                # rpc callers have 3 arguments: stub, input and timeout
                if len(spec.args) != 3:
                    continue
                input_type = spec.annotations[spec.args[1]]
                if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
                    setattr(stub, name, self._stream_request_wrapper(method, name))
                elif input_type is runtime_pb2.ExpertRequest:
                    setattr(stub, name, self._unary_request_wrapper(method, name))
        return stub

    @property
    def _stub(self) -> StubBase:
        # Raw parent stub; overridden in tests to inject a local fake.
        return super().stub

    @property
    def stub(self) -> StubBase:
        # Public stub: every rpc* handler is wrapped to attach "__dust" metadata.
        return self._prioritize_handler_stub_calls(self._stub)

@ -138,7 +138,12 @@ class TransformerConnectionHandler(ConnectionHandler):
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
cache_metadata,
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids, priority=priority

@ -5,16 +5,16 @@ from hivemind.moe.server.task_pool import Task
class TaskPrioritizerBase(ABC):
    """Abstract base for task prioritizers whose responsibility is to evaluate task priority."""

    @abstractmethod
    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        """Evaluate task priority from the amount of points given, the task input tensors
        and additional kwargs. Lower priority value means the task is served earlier.
        """
class DummyTaskPrioritizer(TaskPrioritizerBase):
    """Trivial TaskPrioritizer that assigns constant zero priority to every task."""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        # All tasks are treated equally: always the best (lowest) priority.
        return 0.0

@ -1,99 +0,0 @@
from typing import AsyncIterator, Optional
import pytest
import torch
from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import ExpertRequest
from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
from src.client.priority_block import DustyRemoteBlock
from src.client.spending_policy import SpendingPolicyBase
class SpendingPolicyTest(SpendingPolicyBase):
    """Spending policy stub: fixed per-method point prices, -1 for unknown methods."""

    def __init__(self):
        self._prices = {"rpc_single": 1, "rpc_stream": 2}

    def get_points(self, request: ExpertRequest, method_name: str) -> float:
        # The request itself is ignored; only the method name determines the price.
        return self._prices.get(method_name, -1)
class HandlerStubTest:
    """Minimal stub whose rpc_* handlers simply echo their input back."""

    async def rpc_info(self, input: str, timeout: Optional[float] = None):
        # Non-ExpertRequest input: DustyRemoteBlock must leave this method unwrapped.
        return input

    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
        # Unary echo; the wrapper is expected to tag input.metadata before we see it.
        return input

    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
        # Streaming echo: the async iterator is returned unchanged.
        return input
class RemoteBlockTest(DustyRemoteBlock):
    """DustyRemoteBlock wired to a local HandlerStubTest instead of a real p2p stub."""

    @property
    def _stub(self):
        # Bypass the parent's p2p-backed stub entirely; a fresh fake per access.
        return HandlerStubTest()
@pytest.mark.forked
@pytest.mark.asyncio
async def test_single():
    """Unary rpc: the wrapper must inject {"__dust": 1} while leaving tensors intact."""
    block = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = block.stub
    tensor = torch.randn(1, 2)
    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(tensor)])
    print(stub)

    response: ExpertRequest = await stub.rpc_single(request)

    assert response.metadata != b""
    assert len(response.tensors) == 1
    assert torch.allclose(tensor, deserialize_torch_tensor(response.tensors[0]))

    meta = MSGPackSerializer.loads(response.metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 1
@pytest.mark.forked
@pytest.mark.asyncio
async def test_stream():
    """Streaming rpc: only the first chunk carries "__dust" metadata; the stream round-trips."""
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    # 2**21 x 2 float32 values -> 2**24 bytes of tensor data; split into 2**16-byte
    # chunks that should give 2**8 == 2**5 * 8 parts (assuming negligible per-chunk
    # overhead in split_for_streaming — matches the count asserted below).
    input = torch.randn(2**21, 2)
    split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16))
    output_generator = await stub.rpc_stream(
        amap_in_executor(
            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    outputs_list = [part async for part in output_generator]
    assert len(outputs_list) == 2**5 * 8
    # The stream wrapper tags the first chunk only; every later chunk stays empty.
    assert outputs_list[0].metadata != b""
    for i in range(1, len(outputs_list)):
        assert outputs_list[i].metadata == b""
    meta = MSGPackSerializer.loads(outputs_list[0].metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 2  # SpendingPolicyTest prices rpc_stream at 2 points
    # Reassembling the echoed chunks must reproduce the original tensor exactly.
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
    assert len(results) == 1
    assert torch.allclose(results[0], input)
@pytest.mark.forked
@pytest.mark.asyncio
async def test_no_wrapper():
    """rpc_info takes a plain str, so neither wrapper applies and the call passes through."""
    block = RemoteBlockTest(SpendingPolicyTest(), None, None)
    echoed = await block.stub.rpc_info("Test")
    assert echoed == "Test"
Loading…
Cancel
Save