From 2ad0b2b936629da2076d639da2c1200468ea568e Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Mon, 18 Mar 2024 00:11:47 +0300
Subject: [PATCH] Fix p2p pushing in rpc_inference (by @miaoqijun), support
 transformers 4.38.2 (#563)

This pull request solves #560 using a solution proposed by @miaoqijun.
It also bumps transformers to the latest version to test with the latest code.

---------

Co-authored-by: Yingtong Dou
---
 setup.cfg                            |  2 +-
 src/petals/__init__.py               |  4 ++--
 src/petals/models/llama/block.py     | 17 +++++++++++++----
 src/petals/models/llama/model.py     |  3 +++
 src/petals/server/block_functions.py |  4 ++--
 src/petals/server/handler.py         |  4 ++--
 6 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index b06dd5c..dc0bd4e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.37.1  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.38.2  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index fd38936..ccc560e 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.38.2,<4.39.0"
 
 
 def _override_bfloat16_mode_default():
diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index 6f539a8..2eb8f73 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = residual + hidden_states
diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py
index 611bb2b..1257cd7 100644
--- a/src/petals/models/llama/model.py
+++ b/src/petals/models/llama/model.py
@@ -47,6 +47,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
index 2c37566..a79f05c 100644
--- a/src/petals/server/block_functions.py
+++ b/src/petals/server/block_functions.py
@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)
 
     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
             for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
         can_push = not has_prompts
-        yield output_tensors, can_push
+        yield output_tensors, can_push, step_metadata
 
         # prepare for next step
         prefix_length += length_increment
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d8f0ec0..2465656 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
                 ) as cache_handles:
                     background_tasks = set()
-                    async for output_tensors, can_push in iterate_rpc_inference(
+                    async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
                         active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         args_structure=args_structure,
                     ):
                         if can_push:
-                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
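
Note on the block.py hunks above: as of transformers 4.38, LlamaRotaryEmbedding is indexed by explicit position_ids instead of being sliced by sequence length, so the attention block now derives default position_ids from how many tokens the KV cache already holds. The sketch below uses toy tensor sizes chosen only for illustration (batch 1, one new token, 32 heads of dim 128, 7 cached tokens); just the position_ids/unsqueeze logic is taken from the patch.

    import torch

    # Toy shapes for illustration only: hidden_states is (batch, q_len, hidden),
    # cached K/V are (batch, num_heads, past_len, head_dim) with 7 tokens cached.
    hidden_states = torch.randn(1, 1, 4096)
    past_key_value = tuple(torch.randn(1, 32, 7, 128) for _ in range(2))

    # Same default as the new code path in block.py: continue numbering positions
    # from the number of tokens already stored in the attention cache.
    past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
    position_ids = torch.arange(
        past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
    ).unsqueeze(0)
    print(position_ids)  # tensor([[7]]) -- the new token gets absolute position 7

    # The rotary embedding now returns per-position cos/sin shaped
    # (batch, seq_len, head_dim); unsqueezing dim 1 lets them broadcast against
    # (batch, num_heads, seq_len, head_dim) query/key tensors.
    cos = torch.randn(1, 1, 128)  # stand-in for the rotary_emb(...) output
    sin = torch.randn(1, 1, 128)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    print(cos.shape)  # torch.Size([1, 1, 1, 128])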
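
Note on the handler.py and block_functions.py hunks above: the handler previously pushed outputs with `metadata`, which is parsed from the first request of the stream, while `iterate_rpc_inference` parses fresh metadata for every step; the fix yields that per-step `step_metadata` alongside the outputs and pushes with it. The following self-contained sketch uses hypothetical names (iterate_steps, push_outputs, next_servers), not the real Petals API; it only illustrates the yield-per-step pattern.

    import asyncio
    from typing import Any, AsyncIterator, Dict, List, Tuple

    async def push_outputs(output: str, metadata: Dict[str, Any]) -> None:
        # Stand-in for the real push: just show which peers would receive the output.
        print(f"pushing {output!r} to {metadata.get('next_servers')}")

    async def iterate_steps(requests: List[Dict[str, Any]]) -> AsyncIterator[Tuple[str, bool, Dict[str, Any]]]:
        for request in requests:                   # stands in for the async input iterator
            step_metadata = request["metadata"]    # parsed fresh on every inference step
            output = f"hidden_states[{request['id']}]"
            can_push = bool(step_metadata.get("next_servers"))
            yield output, can_push, step_metadata  # hand the *current* step's metadata back

    async def handler(requests: List[Dict[str, Any]]) -> None:
        async for output, can_push, step_metadata in iterate_steps(requests):
            if can_push:
                # Push with this step's metadata, not with the first request's metadata.
                await push_outputs(output, step_metadata)

    asyncio.run(handler([
        {"id": 0, "metadata": {"next_servers": ["peer_a"]}},
        {"id": 1, "metadata": {"next_servers": ["peer_b"]}},
    ]))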