|
|
|
@ -15,6 +15,7 @@ from hivemind import (
|
|
|
|
|
MSGPackSerializer,
|
|
|
|
|
P2PContext,
|
|
|
|
|
PeerID,
|
|
|
|
|
TensorDescriptor,
|
|
|
|
|
deserialize_tensor_stream,
|
|
|
|
|
deserialize_torch_tensor,
|
|
|
|
|
nested_flatten,
|
|
|
|
@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
async with self._allocate_cache(
|
|
|
|
|
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
|
|
|
|
|
) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
|
|
|
|
|
) as cache_handles, self._load_peft_module(
|
|
|
|
|
requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
|
|
|
|
|
):
|
|
|
|
|
background_tasks = set()
|
|
|
|
|
async for output_tensors, can_push in iterate_rpc_inference(
|
|
|
|
|
requested_uids=requested_uids,
|
|
|
|
@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
def _get_active_adapter(self, metadata: dict) -> str:
|
|
|
|
|
active_adapter = metadata.get("active_adapter", "")
|
|
|
|
|
if active_adapter and (active_adapter not in self.adapters):
|
|
|
|
|
raise KeyError(f"adapter {active_adapter} not found")
|
|
|
|
|
return active_adapter
|
|
|
|
|
if active_adapter:
|
|
|
|
|
return active_adapter
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def _serialize_grads(
|
|
|
|
|
self,
|
|
|
|
@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
yield nested_pack(handles, descriptors)
|
|
|
|
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def _load_peft_module(
    self,
    backends: Sequence[TransformerBackend],
    *,
    active_adapter: str,
    timeout: float,
):
    """Async context manager that makes ``active_adapter`` available on ``backends``.

    If ``active_adapter`` is the empty string or is already listed in
    ``self.adapters``, nothing is loaded and the context body simply runs.

    Otherwise the adapter is loaded for the duration of the context:

    1. its estimated size is obtained from ``get_estimated_peft_module_size``;
    2. a placeholder ``TensorDescriptor`` of that many ``int8`` elements on
       ``backends[0].device`` is allocated through
       ``backends[0].memory_cache.allocate_cache`` (presumably to reserve
       memory budget for the adapter weights -- TODO confirm);
    3. for every backend, the adapter is loaded with ``load_peft`` and
       attached with ``add_adapter_to_block``;
    4. the context body runs;
    5. on exit, normal or exceptional, the adapter is removed from every
       backend and the placeholder allocation is released.

    :param backends: backends serving the requested blocks; ``backends[0]``
        supplies the peft helper module, the device and the memory cache
    :param active_adapter: adapter name, or ``""`` for "no adapter"
    :param timeout: forwarded to ``memory_cache.allocate_cache``
    """
    # Short-circuit on "" first so self.adapters is only consulted for real names.
    if active_adapter == "" or active_adapter in self.adapters:
        yield
        return

    _peft_module = backends[0]._peft_module
    token = None  # TODO: Provide token from user request maybe?

    estimated_peft_size = _peft_module.get_estimated_peft_module_size(
        active_adapter,
        token=token,
    )

    fake_descriptor = TensorDescriptor(
        size=(estimated_peft_size,),
        dtype=torch.int8,
        device=backends[0].device,
    )

    async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout):
        try:
            # NOTE(review): load_peft is called synchronously inside a coroutine;
            # if it performs downloads or disk I/O it may stall the event loop -- confirm.
            for backend in backends:
                adapter_config, adapter_state_dict = _peft_module.load_peft(
                    active_adapter,
                    block_idx=backend.block_index,
                    token=token,
                    cache_dir=backend.cache_dir,
                    max_disk_space=backend.max_disk_space,
                )

                _peft_module.add_adapter_to_block(
                    backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
                )

            # An @asynccontextmanager generator must yield exactly once on every
            # path; without this the `async with` at the call site fails with
            # RuntimeError("generator didn't yield") for dynamically loaded adapters.
            yield
        finally:
            # Runs for all backends even if loading failed part-way.
            # NOTE(review): assumes remove_adapter_from_block tolerates blocks
            # the adapter was never added to -- confirm.
            for backend in backends:
                _peft_module.remove_adapter_from_block(backend.module, active_adapter)
|
|
|
|
|
|
|
|
|
|
def _log_request(
|
|
|
|
|
self,
|
|
|
|
|