Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)

Commit 49474e5477 (parent e5c2d8eca4): "wip some more"
@@ -4,7 +4,7 @@ import asyncio
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Tuple
+from typing import AsyncIterator, List, Optional, Tuple, Sequence

 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -34,7 +34,7 @@ class _ServerInferenceSession:
         self,
         config: ClientConfig,
         span: RemoteSpanInfo,
-        uid: ModuleUID,
+        span_uids: Sequence[ModuleUID],
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
@@ -43,8 +43,8 @@ class _ServerInferenceSession:
         **metadata,
     ):
         self.config = config
-        self.span, self.uid, self.rpc_info = span, uid, rpc_info
-        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
+        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self.session_id = str(uuid.uuid4())
@@ -62,18 +62,19 @@ class _ServerInferenceSession:
         config: ClientConfig,
         p2p: P2P,
         span: RemoteSpanInfo,
-        uid: ModuleUID,
+        span_uids: Sequence[RemoteSpanInfo],
         rpc_info: RPCInfo,
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        # TODO YOZH you don't need rpc info here
         stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
             config.connect_timeout,
         )
-        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)

     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -142,6 +143,7 @@ class _ServerInferenceSession:

         request_metadata["args_structure"] = args_structure

+        # TODO YOZH FIX THIS BEFORE THE END OF THIS PR
         # TODO: make possible to use different compression method for different tensors
         server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
         compression = server_side_inference_schema[0].compression
@@ -155,7 +157,7 @@ class _ServerInferenceSession:
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
-                    uid=self.uid,
+                    uid=CHAIN_DELIMITER.join(self.span_uids),
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(input_tensors, inference_schema)
@@ -244,8 +246,8 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
-                metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
+                span_uids = self._sequence_manager.block_uids[span.start : span.end]
+                metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
                         self._sequence_manager.config,
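The hunks above replace the single merged `uid` string with a `span_uids` list and rebuild the merged form only when an `ExpertRequest` is sent. A minimal sketch of that equivalence (not part of the commit; the block UIDs below are invented for the example):

# Hypothetical illustration of the uid -> span_uids change in _ServerInferenceSession.
# The only assumption is that CHAIN_DELIMITER never occurs inside a single block UID.
from petals.data_structures import CHAIN_DELIMITER

span_uids = ["blockA", "blockB", "blockC"]  # stand-ins for real ModuleUIDs

merged_uid = CHAIN_DELIMITER.join(span_uids)  # what ExpertRequest.uid still carries over the wire
num_blocks = len(span_uids)                   # new: count blocks directly from the list

# The old code derived the same count from the merged string:
assert num_blocks == merged_uid.count(CHAIN_DELIMITER) + 1
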
@@ -2,50 +2,49 @@
 Utility functions that call RPC forward or backward on a single remote server
 """
 import asyncio
-from typing import Iterable, List, Optional, Sequence, Tuple
+from typing import Iterable, List, Sequence, Tuple

 import torch
-from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor, PeerID
+from hivemind import PeerID, nested_flatten, serialize_torch_tensor
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor

-from petals import RemoteSequenceManager
 from petals.client.config import ClientConfig
-from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER
+from petals.client.routing import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.packaging import pack_args_kwargs


 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]


 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]


 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -55,10 +54,10 @@ async def _forward_stream


 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
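With these signatures, the caller can thread request metadata through without the helpers knowing about it: whatever keyword arguments reach `_forward_unary` / `_forward_stream` (and their backward counterparts) are copied into every `runtime_pb2.ExpertRequest` they build, so in the streaming case each part carries the same metadata. A small caller-side sketch (not part of the commit; `stub` and `config` are placeholders for a real connection handler stub and ClientConfig, and the metadata payload is invented):

# Hypothetical usage sketch; only the call shape is taken from the diff above.
import torch
from hivemind import MSGPackSerializer, serialize_torch_tensor
from hivemind.proto import runtime_pb2

serialized = [serialize_torch_tensor(torch.zeros(2, 3), runtime_pb2.CompressionType.NONE)]
metadata = MSGPackSerializer.dumps({"example_key": "example_value"})  # hypothetical payload

# outputs = await _forward_unary(merged_uid, serialized, stub, config, metadata=metadata)
# For payloads larger than MAX_UNARY_PAYLOAD_SIZE // 2 the caller picks _forward_stream instead,
# and the same metadata= kwarg is attached to every streamed part.
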
@@ -81,35 +80,39 @@ async def run_remote_forward(
     """
     merged_uid = CHAIN_DELIMITER.join(span_uids)
     stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
-    flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs)
     metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs)
-    compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
-    if compressions is None:
-        compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs)
-    compressions = list(nested_flatten(compressions))
-    assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs"
-    inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs)
+    codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
+    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
+    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
+    args_structure = metadata.setdefault("args_structure", args_structure)
+    if codecs is None:
+        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
+    else:
+        codecs = list(nested_flatten(codecs))
+    assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"

     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
-            for tensor, compression in zip(inputs, compressions)
+            for tensor, compression in zip(flat_tensors, codecs)
         )
     )

     # call RPC on remote server
-    size = sum(t.element_size() * t.nelement() for t in inputs)
+    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
-    return await forward_fn(merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata)
+    return await forward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
+    )


 async def run_remote_backward(
     sequence_manager: RemoteSequenceManager,
+    peer_id: PeerID,
     span_uids: Sequence[ModuleUID],
-    stub: StubBase,
     grad_outputs: Sequence[torch.Tensor],
     *args: torch.Tensor,
     **kwargs: torch.Tensor,
@@ -119,23 +122,32 @@ async def run_remote_backward(
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-    flat_tensors, args_structure = pack_args_kwargs(
-        [grad.cpu() for grad in grad_outputs], args, kwargs
-    )
-    metadata = sequence_manager.get_request_metadata(
-        "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
-    )
+    merged_uid = CHAIN_DELIMITER.join(span_uids)
+    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
+    metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
+    codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
+    flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs)
+    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
+    args_structure = metadata.setdefault("args_structure", args_structure)
+
+    if codecs is None:
+        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
+    else:
+        codecs = list(nested_flatten(codecs))
+    assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"

     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
-            loop.run_in_executor(None, serialize_torch_tensor, compression)
-            for tensor, proto in zip(flat_inputs_and_grad_outputs, backward_schema)
+            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
+            for tensor, compression in zip(flat_tensors, codecs)
         )
     )

-    size = sum(t.element_size() * t.nelement() for t in flat_inputs_and_grad_outputs)
+    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
+    return await backward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
+    )
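Both helpers now follow the same client-side pattern: flatten the (args, kwargs) pytree into a list of tensors, ask the sequence manager for per-tensor compression codecs (falling back to NONE), record the pytree structure in the request metadata, then serialize each tensor with its codec. A condensed sketch of the codec-alignment step (not part of the commit; it uses a plain list of tensors instead of a RemoteSequenceManager):

# Illustrative only: mirrors the flatten -> codecs -> serialize steps from the diff above.
from typing import Optional, Sequence

import torch
from hivemind import nested_flatten, serialize_torch_tensor
from hivemind.proto import runtime_pb2


def serialize_with_codecs(flat_tensors: Sequence[torch.Tensor], codecs: Optional[Sequence] = None):
    if codecs is None:
        # No preference from the sequence manager: send everything uncompressed
        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
    else:
        codecs = list(nested_flatten(codecs))
    assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
    return [serialize_torch_tensor(t.cpu().detach(), codec) for t, codec in zip(flat_tensors, codecs)]


serialized = serialize_with_codecs([torch.randn(2, 4), torch.randn(2, 4)])  # defaults to NONE compression
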
@@ -474,7 +474,9 @@ class RemoteSequenceManager:
             return 0
         return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)

-    def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Dict[str, Any]]:
+    def get_request_metadata(
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Dict[str, Any]]:
         """
         :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
@@ -488,7 +490,8 @@ class RemoteSequenceManager:
         )

     def get_compression_codecs(
-        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
         """
         :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
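Both hooks now receive the target peer and the span's block UIDs up front, so a subclass can choose metadata or codecs per server and per protocol. A sketch of how a subclass might use the new signatures (illustrative; the extra metadata field is hypothetical and not part of this commit):

# Hypothetical subclass showing the new (peer_id, protocol, uids, *args, **kwargs) hook signatures.
from typing import Any, Dict, Optional, Sequence

from hivemind import PeerID
from hivemind.proto import runtime_pb2

from petals.client.routing import RemoteSequenceManager


class VerboseSequenceManager(RemoteSequenceManager):
    def get_request_metadata(
        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
    ) -> Optional[Dict[str, Any]]:
        metadata = super().get_request_metadata(peer_id, protocol, uids, *args, **kwargs) or {}
        metadata["num_blocks"] = len(uids)  # hypothetical extra field, just to show where it would go
        return metadata

    def get_compression_codecs(
        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
    ) -> Optional[Sequence["runtime_pb2.CompressionType.ValueType"]]:
        # Returning None tells the caller to fall back to CompressionType.NONE for every flattened tensor
        return super().get_compression_codecs(peer_id, protocol, uids, *args, **kwargs)
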
@@ -4,19 +4,16 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 import asyncio
 import itertools
-from collections import deque
-from typing import List, Optional, Sequence, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import torch
-from hivemind import MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger

 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
-from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
-from petals.server.handler import TransformerConnectionHandler
+from petals.data_structures import RemoteSpanInfo
 from petals.utils.misc import DUMMY, is_dummy
 from petals.utils.packaging import pack_args_kwargs

 logger = get_logger(__name__)
@@ -24,12 +21,12 @@ MAX_TOKENS_IN_BATCH = 1024


 async def sequential_forward(
+    sequence_manager: RemoteSequenceManager,
     inputs: torch.Tensor,
     prompts: torch.Tensor,
-    sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
-    block_kwargs: Sequence[Dict[str, Any]] = (),
+    *block_kwargs: Dict[str, Any],
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -45,13 +42,6 @@ async def sequential_forward(
     :param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
     """

-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert len(block_kwargs) in (
-        0,
-        1,
-        end_index - start_index,
-    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
-
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
@@ -59,6 +49,9 @@ async def sequential_forward(

     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert len(block_kwargs) in (0, 1, end_index - start_index), \
+        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience
@@ -87,13 +80,13 @@ async def sequential_forward(
                     sequence_manager.block_uids[span.start : span.end],
                     inputs,
                     prompts[span.start : span.end],
-                    *block_kwargs[span.start : span.end]
+                    *block_kwargs[span.start : span.end],
                 )

                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

-                # Save intermediate inputs and subsequ_peerences if the forward is already done for them
+                # Save intermediate inputs and subsequences if the forward is already done for them
                 intermediate_inputs.append(inputs)
                 done_sequences.append(span)

@@ -118,11 +111,12 @@ async def sequential_forward(


 async def sequential_backward(
+    sequence_manager: RemoteSequenceManager,
+    forward_sequences: List[RemoteSpanInfo],
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
     prompts: torch.Tensor,
-    forward_sequences: List[RemoteSpanInfo],
-    sequence_manager: RemoteSequenceManager,
+    *block_kwargs: Dict[str, Any],
 ) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
@@ -148,7 +142,7 @@ async def sequential_backward(
             try:
                 if attempt_no >= 1:
                     _, backup_inputs, backup_sequences = await sequential_forward(
-                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                        sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
                     )
                     assert len(backup_inputs) == len(backup_sequences)
                     assert backup_sequences[0].start == span.start
@@ -159,14 +153,13 @@ async def sequential_backward(
                 inputs = intermediate_inputs.pop()
                 span = forward_sequences.pop()

-
-                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     sequence_manager,
-                    sequence_manager.block_uids[span.start: span.end],
-                    span_uids,
-                    grad_outputs, inputs,
+                    span.peer_id,
+                    sequence_manager.block_uids[span.start : span.end],
+                    grad_outputs,
+                    *inputs,
+                    *block_kwargs[span.start : span.end],
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
@@ -198,7 +191,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(
         *[
-            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            sequential_forward(sequence_manager, input_batch, prompt_batch)
             for input_batch, prompt_batch in zip(input_batches, prompt_batches)
         ]
     )
@@ -210,7 +203,7 @@ async def _gather_backward(
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch)
             for grad_output, input_batch, prompt_batch, spans in zip(
                 grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
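After this change, both entry points take the sequence manager first and accept per-block keyword arguments as trailing positional dicts, one per block, which the loop slices per span with block_kwargs[span.start : span.end]. A call-shape sketch (illustrative; the sequence manager, tensor shapes, and kwargs below are placeholders):

# Hypothetical call shapes only; the hidden size and block count are invented for the example.
import torch

from petals.utils.misc import DUMMY

# inputs: (batch, seq_len, hidden); prompts: DUMMY or one prompt tensor per block
inputs = torch.randn(1, 8, 4096)
prompts = DUMMY
per_block_kwargs = [{} for _ in range(4)]  # one dict per remote block in the chosen range

# A synchronous caller would drive the coroutine via RemoteExpertWorker.run_coroutine, e.g.:
# outputs, intermediates, spans = RemoteExpertWorker.run_coroutine(
#     sequential_forward(sequence_manager, inputs, prompts, 0, 4, *per_block_kwargs)
# )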