Add deep prompt inference (#66)

Add deep prompt in inference_step. Small refactoring in deep prompt code.
pull/60/head^2
Artem Chumachenko 2 years ago committed by GitHub
parent 54ad745bed
commit ada98a1b37
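A minimal sketch of how the extended client-side API introduced by this commit is meant to be called. The tensor sizes below are illustrative only; `sess` stands for a RemoteSequentialInferenceSession and is assumed, not constructed here.

import torch

num_layers, batch_size, prefix_len, hid_size = 4, 1, 16, 1024  # toy sizes, not from the commit
intermediate_prompts = torch.zeros(num_layers, batch_size, prefix_len, hid_size)
hypo_ids = torch.arange(batch_size)  # identity reordering of beam hypotheses

# `sess` would be a RemoteSequentialInferenceSession; it needs a live swarm, hence commented out:
# hidden_states = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)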

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
max_length: int,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
def step(self, new_hidden_states: torch.Tensor):
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
def step(
self,
new_hidden_states: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
):
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
assert prompts.shape[2] <= new_hidden_states.shape[1]
assert prompts.shape[3] == new_hidden_states.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
else:
assert len(hypo_ids) == len(new_hidden_states)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
inputs = (new_hidden_states,)
inputs = (new_hidden_states, prompts, hypo_ids)
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
],
metadata=self._serialized_metadata if not self.stepped else None,
)
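A standalone recreation of the shape contract that the new asserts in step() enforce, using toy tensors (sizes below are made up, not part of the diff):

import torch

new_hidden_states = torch.randn(2, 8, 1024)   # [batch_size, seq_len, hid_size]
prompts = torch.randn(3, 1, 4, 1024)          # [num_blocks, 1 or batch_size, prefix_len, hid_size]
num_blocks = 3

assert prompts.ndim == 4
assert prompts.shape[0] == num_blocks                       # one prompt slice per block in the chain
assert prompts.shape[1] in (new_hidden_states.shape[0], 1)  # batch dimension may broadcast
assert prompts.shape[2] <= new_hidden_states.shape[1]       # prefix no longer than the current chunk
assert prompts.shape[3] == new_hidden_states.shape[2]       # hidden sizes must match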
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:
return self
def step(self, inputs: torch.Tensor):
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
assert not self.closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
for session in self.inference_sessions:
outputs = session.step(inputs)
outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs
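How the per-server slicing of deep prompts above works, sketched with a hypothetical span object (the Span dataclass stands in for RemoteSpanInfo; block indices and sizes are illustrative):

import torch
from dataclasses import dataclass

@dataclass
class Span:  # stand-in for RemoteSpanInfo; only start/end matter here
    start: int
    end: int

prompts = torch.zeros(12, 1, 16, 1024)  # deep prompts for a hypothetical 12-block model
span = Span(start=3, end=7)             # this server hosts blocks 3..6
prompts_for_server = prompts[span.start : span.end]
assert prompts_for_server.shape[0] == span.end - span.start  # 4 per-block prompt slices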

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
hypo_ids = torch.arange(outputs[0].size(0))
while True:
embs = self.transformer.word_embeddings(outputs[-1])
intermediate_prompts = None
if self.config.pre_seq_len > 0 and len(outputs) == 1:
prompts, _ = self.transformer.get_prompt(embs.size(0))
prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
embs = torch.cat([prompts, embs], dim=1)
embs = self.transformer.word_embeddings_layernorm(embs)
hidden_state = sess.step(embs)[:, -1]
hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
hidden_state = self.transformer.ln_f(hidden_state)
lm_logits = self.lm_head(hidden_state)
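The first-step prompt handling in the generation loop above, in isolation. The fake_get_prompt helper below is a stand-in for transformer.get_prompt() and the sizes are illustrative, not the model's real dimensions:

import torch

batch_size, prefix_len, hid_size, num_layers = 1, 16, 1024, 12

def fake_get_prompt(batch_size):
    # stand-in for transformer.get_prompt(): returns (input prompts, per-layer deep prompts)
    return (
        torch.zeros(batch_size, prefix_len, hid_size),
        torch.zeros(num_layers, batch_size, prefix_len, hid_size),
    )

embs = torch.zeros(batch_size, 1, hid_size)  # embedding of the latest token
prompts, intermediate_prompts = fake_get_prompt(batch_size)
embs = torch.cat([prompts, embs], dim=1)     # prefix prompts are prepended only on the first step
assert embs.shape == (batch_size, prefix_len + 1, hid_size)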

@@ -1,15 +1,16 @@
"""Code for serving bloom blocks via hivemind-server"""
from queue import Empty
from typing import Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
from hivemind import use_hivemind_log_handler
from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.utils import InvalidStateError, get_logger
from src.bloom.from_pretrained import BloomBlock
from src.server.cache import MemoryCache
from src.utils.misc import is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -55,18 +56,28 @@ class TransformerBackend(ModuleBackend):
self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
)
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
self.inference_schema = (
(
*self.args_schema,
BatchTensorDescriptor((), dtype=self.dtype),
BatchTensorDescriptor((), dtype=torch.int64),
),
self.kwargs_schema,
)
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
with torch.inference_mode():
attention_cache_handle = int(cache_metadata[0, 0].item())
prefix_length = int(cache_metadata[0, 1].item())
hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
(hidden_states, hypo_ids) = inputs
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(attention_cache_handle) as cache:
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
if not is_dummy(hypo_ids):
cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(
@@ -85,3 +96,7 @@ class TransformerBackend(ModuleBackend):
def get_pools(self) -> Sequence[TaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
def get_info(self) -> Dict[str, Any]:
"""Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
return dict(super().get_info(), inference_schema=self.inference_schema)
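The hypo_ids handling in inference_step above reorders the attention cache along the batch dimension with advanced indexing; a toy check of that trick (cache shape below is simplified relative to the real 5-dimensional cache):

import torch

cache = torch.randn(2, 3, 5, 8)     # toy [keys/values, batch, seq_len, head_dim] cache
hypo_ids = torch.tensor([2, 0, 0])  # beam search kept hypothesis 2 plus two copies of hypothesis 0
expected = cache[:, hypo_ids]       # advanced indexing returns a copy, so the overlap is safe
cache[:, :] = cache[:, hypo_ids]    # same in-place reorder as in inference_step
assert torch.equal(cache, expected)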

@@ -64,41 +64,56 @@ class TransformerConnectionHandler(ConnectionHandler):
async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
length_increment = hidden_states[0].shape[1] # how many tokens are added this step (in each seq)
hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
# run request tensors through all requested modules, update caches
for backend, cache_handle in zip(requested_backends, cache_handles):
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
len(hidden_states) == 1 and hidden_states[0].ndim == 3
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
assert isinstance(hidden_states, (list, tuple))
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids
)
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
)
# prepare for next step
prefix_length += hidden_states[0].shape[1]
prefix_length += hidden_states.shape[1]
request = await (anext(requests))
finally:
print("CLOSED RPC_INFERENCE")
@@ -238,23 +253,20 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, *prompts = flat_tensors
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if not prompts or is_dummy(prompts[0]):
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, :pre_seq_len] += prompt
hidden_states[:, : prompt.shape[1]] += prompt
(hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
assert isinstance(hidden_states, torch.Tensor)
assert (
@@ -268,18 +280,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
async def _rpc_backward(
*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, *prompts = flat_tensors
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if not prompts or is_dummy(prompts[0]):
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
pre_seq_len = 0
else:
prompts = [prompts[0].to(requested_backends[0].dtype)]
prompts = [p.squeeze(0) for p in prompts[0].split(1)]
pre_seq_len = prompts[0].shape[-2]
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
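The prompt unpacking used in both _rpc_forward and _rpc_backward above turns one stacked tensor into a list with one [batch_size, prefix_len, hid_size] slice per block; a minimal check with toy sizes:

import torch

prompts = torch.randn(3, 2, 4, 8)  # [num_blocks, batch_size, prefix_len, hid_size]
per_block = [p.squeeze(0) for p in prompts.split(1, dim=0)]
assert len(per_block) == 3
assert per_block[0].shape == (2, 4, 8)  # one per-block slice, ready to be added to that block's inputs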
@@ -287,13 +296,13 @@ async def _rpc_backward(
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, :pre_seq_len] += prompt
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
(inputs,) = await backend.forward_pool.submit_task(inputs)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, :pre_seq_len] += prompts[-1]
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
@@ -303,7 +312,7 @@ async def _rpc_backward(
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
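A toy recreation of the grad-prompt bookkeeping above: per-block prompt gradients are collected while walking the chain backwards, then reversed and stacked back into [num_blocks, batch_size, prefix_len, hid_size]. All sizes and fill values below are illustrative:

import torch

num_blocks, batch, prefix_len, hid = 3, 2, 4, 8
grad_prompts_reversed = []
for block_idx in reversed(range(num_blocks)):
    grad_outputs = torch.full((batch, 16, hid), float(block_idx))  # pretend gradient w.r.t. block inputs
    grad_prompts_reversed.append(grad_outputs[:, :prefix_len].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0)
assert grad_prompts.shape == (num_blocks, batch, prefix_len, hid)
assert grad_prompts[0].eq(0).all() and grad_prompts[2].eq(2).all()  # order matches the block order again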
