rollback: only generic kwarg

partial_rollback
Your Name 8 months ago
parent 6256995bb1
commit 6c7f762379

@ -53,22 +53,13 @@ class TransformerBackend(ModuleBackend):
max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
lambda args, kwargs: self.inference_step(*args, **kwargs),
max_batch_size=max_batch_size,
device=device,
name=f"{self.name}_inference",
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
lambda args, kwargs: self.forward(*args, **kwargs),
max_batch_size=max_batch_size,
device=device,
name=f"{self.name}_forward",
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
lambda args, kwargs: self.backward(*args, **kwargs),
max_batch_size=max_batch_size,
device=device,
name=f"{self.name}_backward",
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
)
self.dtype = backend_dtype

@ -770,3 +770,15 @@ class RuntimeWithDeduplicatedPools(Runtime):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pools = tuple(set(self.pools))
def process_batch(
self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Tuple[Any, int]:
"""process one batch of tasks from a given pool, return a batch of results and total batch size"""
outputs = pool.process_func(*args, **kwargs)
batch_size = 1
for arg in args:
if isintance(arg, torch.Tensor) and arg.ndim > 2:
batch_size = arg.shape[0] * arg.shape[1]
break
return outputs, batch_size

Loading…
Cancel
Save