|
|
|
@ -53,22 +53,13 @@ class TransformerBackend(ModuleBackend):
|
|
|
|
|
max_batch_size = self.forward_pool.max_batch_size
|
|
|
|
|
device = self.module.devices[self.module.output_device_index]
|
|
|
|
|
self.inference_pool = PrioritizedTaskPool(
|
|
|
|
|
lambda args, kwargs: self.inference_step(*args, **kwargs),
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
device=device,
|
|
|
|
|
name=f"{self.name}_inference",
|
|
|
|
|
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
|
|
|
|
|
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
|
|
|
|
|
self.forward_pool = PrioritizedTaskPool(
|
|
|
|
|
lambda args, kwargs: self.forward(*args, **kwargs),
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
device=device,
|
|
|
|
|
name=f"{self.name}_forward",
|
|
|
|
|
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
|
|
|
|
|
)
|
|
|
|
|
self.backward_pool = PrioritizedTaskPool(
|
|
|
|
|
lambda args, kwargs: self.backward(*args, **kwargs),
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
device=device,
|
|
|
|
|
name=f"{self.name}_backward",
|
|
|
|
|
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.dtype = backend_dtype
|
|
|
|
|