|
|
|
@ -152,6 +152,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
session_id = metadata.get("session_id")
|
|
|
|
|
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
|
|
|
|
|
args_structure = metadata.get("args_structure")
|
|
|
|
|
active_adapter = self._get_active_adapter(metadata)
|
|
|
|
|
if not requested_uids:
|
|
|
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
|
|
|
assert isinstance(
|
|
|
|
@ -169,12 +170,12 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
async with self._allocate_cache(
|
|
|
|
|
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
|
|
|
|
|
) as cache_handles:
|
|
|
|
|
) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
|
|
|
|
|
background_tasks = set()
|
|
|
|
|
async for output_tensors, can_push in iterate_rpc_inference(
|
|
|
|
|
requested_uids=requested_uids,
|
|
|
|
|
requested_backends=requested_backends,
|
|
|
|
|
active_adapter=self._get_active_adapter(metadata),
|
|
|
|
|
active_adapter=active_adapter,
|
|
|
|
|
input_iterator=self._iterate_inference_steps(
|
|
|
|
|
request, requests, session_id, requested_uids, context
|
|
|
|
|
),
|
|
|
|
@ -546,6 +547,33 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
|
|
|
|
|
yield nested_pack(handles, descriptors)
|
|
|
|
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def _load_peft_module(self, backends: Sequence["TransformerBackend"], active_adapter: str):
    """Keep ``active_adapter`` attached to ``backends`` for the duration of the ``async with`` body.

    Nothing is loaded or removed when ``active_adapter`` is the empty string or is
    already listed in ``self.adapters``. Otherwise the adapter is loaded for every
    backend via ``load_peft`` / ``add_adapter_to_block``, control is handed to the
    ``async with`` body, and the adapter is removed from every backend on exit,
    including when loading or the body raises.

    :param backends: backends serving the requested blocks; the first one supplies
        the ``_peft_module`` helper used for all of them
    :param active_adapter: adapter name, or ``""`` for no adapter
    :raises IndexError: if an adapter must be loaded but ``backends`` is empty
    """
    if active_adapter == "" or active_adapter in self.adapters:
        # No adapter requested, or the adapter is already known to this handler.
        yield
        return

    # Bound before the ``try`` so that the cleanup below can never hit an unbound name
    # and mask the original error.
    _peft_module = backends[0]._peft_module
    token = None  # TODO: Provide token from user request maybe?

    try:
        for backend in backends:
            adapter_config, adapter_state_dict = _peft_module.load_peft(
                active_adapter,
                block_idx=backend.block_index,
                token=token,
                cache_dir=backend.cache_dir,
                max_disk_space=backend.max_disk_space,
            )

            _peft_module.add_adapter_to_block(
                backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
            )

        # Hand control to the ``async with`` body while the adapter is attached. Without this
        # yield the context manager fails on entry with "generator didn't yield".
        yield
    finally:
        # NOTE(review): removal is attempted for every backend, even one that loading never
        # reached - assumes remove_adapter_from_block tolerates a missing adapter; confirm.
        for backend in backends:
            _peft_module.remove_adapter_from_block(backend.module, active_adapter)
|
|
|
|
|
|
|
|
|
|
def _log_request(
|
|
|
|
|
self,
|
|
|
|
|
method: str,
|
|
|
|
|