From 6c7f7623799aaf527c5ac125b1b7c5dbb09e603a Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 5 Sep 2023 05:21:12 +0300
Subject: [PATCH] rollback: only generic kwarg

---
 src/petals/server/backend.py | 15 +++------------
 src/petals/server/server.py  | 12 ++++++++++++
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 3d95bcc..9ae672b 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -53,22 +53,13 @@ class TransformerBackend(ModuleBackend):
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.inference_step(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_inference",
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.forward(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_forward",
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            lambda args, kwargs: self.backward(*args, **kwargs),
-            max_batch_size=max_batch_size,
-            device=device,
-            name=f"{self.name}_backward",
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index fd9f766..c85108a 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -770,3 +770,15 @@ class RuntimeWithDeduplicatedPools(Runtime):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pools = tuple(set(self.pools))
+
+    def process_batch(
+        self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any]
+    ) -> Tuple[Any, int]:
+        """process one batch of tasks from a given pool, return a batch of results and total batch size"""
+        outputs = pool.process_func(*args, **kwargs)
+        batch_size = 1
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
+                batch_size = arg.shape[0] * arg.shape[1]
+                break
+        return outputs, batch_size
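
A minimal, self-contained sketch of the call path this patch relies on: the
pool now stores the bound method directly in `process_func`, and the runtime
unpacks `args`/`kwargs` itself in `process_batch`. `DummyPool` and the tensor
shapes below are illustrative assumptions, not Petals code:

    from typing import Any, Dict, Sequence, Tuple

    import torch


    class DummyPool:
        """Hypothetical stand-in for PrioritizedTaskPool: keeps the callable as-is."""

        def __init__(self, process_func):
            self.process_func = process_func  # e.g. backend.forward, no lambda wrapper


    def process_batch(pool: DummyPool, args: Sequence[Any], kwargs: Dict[str, Any]) -> Tuple[Any, int]:
        """Mirrors the batch-size accounting added to RuntimeWithDeduplicatedPools above."""
        outputs = pool.process_func(*args, **kwargs)
        batch_size = 1
        for arg in args:
            # batched activations are assumed to be [batch, seq_len, hidden_dim],
            # so batch * seq_len counts the tokens processed in this batch
            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
                batch_size = arg.shape[0] * arg.shape[1]
                break
        return outputs, batch_size


    if __name__ == "__main__":
        pool = DummyPool(lambda x: x * 2)  # toy process_func standing in for a block's forward
        hidden = torch.zeros(2, 5, 16)  # batch=2, seq_len=5, hidden_dim=16
        outputs, n_tokens = process_batch(pool, (hidden,), {})
        assert n_tokens == 10  # 2 * 5

With the lambda wrappers rolled back, all three pools share the same generic
(args, kwargs) dispatch, which is what the subject line refers to.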