|
|
|
@ -15,6 +15,7 @@ from hivemind import (
|
|
|
|
|
MSGPackSerializer,
|
|
|
|
|
P2PContext,
|
|
|
|
|
PeerID,
|
|
|
|
|
TensorDescriptor,
|
|
|
|
|
deserialize_tensor_stream,
|
|
|
|
|
deserialize_torch_tensor,
|
|
|
|
|
nested_flatten,
|
|
|
|
@ -152,6 +153,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
session_id = metadata.get("session_id")
|
|
|
|
|
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
|
|
|
|
|
args_structure = metadata.get("args_structure")
|
|
|
|
|
active_adapter = self._get_active_adapter(metadata)
|
|
|
|
|
if not requested_uids:
|
|
|
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
|
|
|
assert isinstance(
|
|
|
|
@ -169,12 +171,14 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
async with self._allocate_cache(
|
|
|
|
|
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
|
|
|
|
|
) as cache_handles:
|
|
|
|
|
) as cache_handles, self._load_peft_module(
|
|
|
|
|
requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
|
|
|
|
|
):
|
|
|
|
|
background_tasks = set()
|
|
|
|
|
async for output_tensors, can_push in iterate_rpc_inference(
|
|
|
|
|
requested_uids=requested_uids,
|
|
|
|
|
requested_backends=requested_backends,
|
|
|
|
|
active_adapter=self._get_active_adapter(metadata),
|
|
|
|
|
active_adapter=active_adapter,
|
|
|
|
|
input_iterator=self._iterate_inference_steps(
|
|
|
|
|
request, requests, session_id, requested_uids, context
|
|
|
|
|
),
|
|
|
|
@ -489,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
def _get_active_adapter(self, metadata: dict) -> str:
|
|
|
|
|
active_adapter = metadata.get("active_adapter", "")
|
|
|
|
|
if active_adapter and (active_adapter not in self.adapters):
|
|
|
|
|
raise KeyError(f"adapter {active_adapter} not found")
|
|
|
|
|
return active_adapter
|
|
|
|
|
if active_adapter:
|
|
|
|
|
return active_adapter
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def _serialize_grads(
|
|
|
|
|
self,
|
|
|
|
@ -546,6 +550,51 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
|
|
|
|
|
yield nested_pack(handles, descriptors)
|
|
|
|
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def _load_peft_module(
    self,
    backends: Sequence[TransformerBackend],
    *,
    active_adapter: str,
    timeout: float,
):
    """Temporarily attach a PEFT (LoRA) adapter to every backend block for the duration of a request.

    No-op when no adapter is requested (``active_adapter == ""``) or when the adapter is
    already among the server's preloaded adapters (``self.adapters``). Otherwise the adapter
    is downloaded/loaded on the fly, added to each block, and removed again on exit —
    including on error — via the ``finally`` clause.

    :param backends: the transformer block backends this request spans
    :param active_adapter: adapter name to load ("" means none)
    :param timeout: max time to wait for the memory-cache reservation, in seconds
        (presumably; same units as ``allocate_cache`` elsewhere in this file — TODO confirm)
    """
    if active_adapter == "":
        # Nothing requested — run the request with the bare blocks.
        yield
    elif active_adapter in self.adapters:
        # Adapter was preloaded at server start; no per-request work needed.
        yield
    else:
        _peft_module = backends[0]._peft_module
        token = None  # TODO: Provide token from user request maybe?

        estimated_peft_size = _peft_module.get_estimated_peft_module_size(
            active_adapter,
            token=token,
        )

        # Reserve memory for the adapter by allocating a same-sized dummy int8 tensor in the
        # shared cache, so concurrent requests cannot oversubscribe device memory. The handle
        # itself is unused ("as _") — only the reservation matters.
        fake_descriptor = TensorDescriptor(
            size=(estimated_peft_size,),
            dtype=torch.int8,
            device=backends[0].device,
        )

        async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
            try:
                for backend in backends:
                    adapter_config, adapter_state_dict = _peft_module.load_peft(
                        active_adapter,
                        block_idx=backend.block_index,
                        token=token,
                        cache_dir=backend.cache_dir,
                        max_disk_space=backend.max_disk_space,
                    )

                    _peft_module.add_adapter_to_block(
                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
                    )
            finally:
                # Always detach the adapter from every block, even if a mid-loop failure left
                # only some blocks patched (remove is presumably tolerant of absent adapters —
                # NOTE(review): confirm remove_adapter_from_block handles never-added blocks).
                for backend in backends:
                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
|
|
|
|
|
|
|
|
|
|
def _log_request(
|
|
|
|
|
self,
|
|
|
|
|
method: str,
|
|
|
|
|