Pack of Inference Changes (#37)

* Bring back multibatch mode

* Add tests

* Fixes
Artem Chumachenko committed via GitHub
parent 6573076883
commit d989b94614

@@ -17,6 +17,7 @@ class RemoteGenerationMixin:
    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
@@ -27,6 +28,7 @@ class RemoteGenerationMixin:
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
@@ -63,6 +65,10 @@ class RemoteGenerationMixin:
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - inputs.size(1)
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]])
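For context: with this change, max_length is converted into max_new_tokens relative to the prompt length, so the two calls below request the same number of generated tokens. A minimal sketch (not part of the diff), assuming a loaded DistributedBloomForCausalLM as model and a matching tokenizer, both placeholders here as in the test further down:

prefix = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# equivalent requests: 4 tokens beyond the prefix
out_a = model.generate(prefix, max_new_tokens=4)
out_b = model.generate(prefix, max_length=prefix.size(1) + 4)

# passing max_length <= prefix.size(1) would trip the assertion added above
assert out_a.shape == out_b.shape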

@@ -17,8 +17,6 @@ from src.bloom.model import (
)
from src.client.remote_generation import RemoteGenerationMixin
from src.client.remote_sequential import RemoteSequential
from src.utils.generation_algorithms import DecodingAlgorithm
from src.utils.generation_constraints import ABCBloomConstraint
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -156,7 +154,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
        return transformer_outputs


class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig
@@ -190,33 +188,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
        self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
        self.lm_head.bias[...] = new_lm_head.bias

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.Tensor:
        return RemoteGenerationMixin.generate(
            self,
            inputs=inputs,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            max_new_tokens=max_new_tokens,
            decoding_algorithm=decoding_algorithm,
            provided_constraints=provided_constraints,
            **model_kwargs,
        )


class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    config_class = DistributedBloomConfig
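Why the base classes were reordered: Python resolves generate() left to right along the MRO, so listing RemoteGenerationMixin before BloomForCausalLM lets the mixin's generate() take precedence automatically, which is what makes the explicit forwarding override above removable. A minimal, self-contained illustration with toy classes (not project code):

class HFStyleModel:               # stands in for BloomForCausalLM
    def generate(self):
        return "huggingface generate"


class RemoteMixin:                # stands in for RemoteGenerationMixin
    def generate(self):
        return "remote generate"


class OldOrder(HFStyleModel, RemoteMixin):  # mixin is shadowed, so an explicit override was needed
    pass


class NewOrder(RemoteMixin, HFStyleModel):  # mixin wins via the MRO, no override needed
    pass


assert OldOrder().generate() == "huggingface generate"
assert NewOrder().generate() == "remote generate"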

@@ -1,16 +1,46 @@
"""Code for serving bloom blocks via hivemind-server"""
from queue import Empty
from typing import Sequence, Tuple
import torch
from hivemind import use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.utils import InvalidStateError, get_logger
from src.bloom.from_pretrained import BloomBlock
from src.server.cache import MemoryCache
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
MAX_LENGTH = 2048
class InferenceTaskPool(TaskPool):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"

    def iterate_minibatches(self, *args, **kwargs):
        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
        while True:
            try:
                logger.debug(f"{self.name} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            try:
                if task.future.set_running_or_notify_cancel():
                    yield [task]
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")


class TransformerBackend(ModuleBackend):
    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
@@ -23,7 +53,9 @@ class TransformerBackend(ModuleBackend):
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"

        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
        self.inference_pool = InferenceTaskPool(
            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
        )

    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        with torch.inference_mode():
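To make the new batching path concrete, here is a standalone sketch of the pattern InferenceTaskPool.iterate_minibatches follows, written with toy stand-ins rather than hivemind's Task/TaskPool types and omitting the InvalidStateError handling: pull from a queue with a timeout, retry on empty reads, skip futures that were already cancelled, and yield live tasks one per minibatch.

from concurrent.futures import Future
from queue import Empty, Queue
from typing import Iterator, List, NamedTuple


class ToyTask(NamedTuple):  # hypothetical stand-in for hivemind's Task
    future: Future
    payload: str


def iterate_single_task_batches(tasks: "Queue[ToyTask]", timeout: float = 1.0) -> Iterator[List[ToyTask]]:
    while True:
        try:
            task = tasks.get(timeout=timeout)
        except Empty:
            continue  # keep waiting, mirroring the retry above
        if task.future.set_running_or_notify_cancel():
            yield [task]  # one task per yielded minibatch, like the code above


# usage: the cancelled task is skipped, the live one is yielded
queue: "Queue[ToyTask]" = Queue()
live, cancelled = ToyTask(Future(), "live"), ToyTask(Future(), "cancelled")
cancelled.future.cancel()
queue.put(cancelled)
queue.put(live)
assert next(iterate_single_task_batches(queue)) == [live]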

@@ -4,6 +4,7 @@ import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM
use_hivemind_log_handler("in_root_logger")
@@ -54,3 +55,32 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    else:
        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
        assert False


@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"

    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search are not identical to HF in multibatch mode"
