Fix p2p pushing in rpc_inference (by @miaoqijun), support transformers 4.38.2 (#563)

This pull request solves #560 using a solution proposed by @miaoqijun.
It also bumps transformers to the latest release (4.38.2) so that Petals is tested against the latest upstream code.
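In short, iterate_rpc_inference now yields the per-step metadata together with the output tensors, and the connection handler forwards that step metadata (rather than the session-level request metadata) when pushing activations to the next server. Below is a minimal sketch of that flow with simplified, illustrative names (iterate_steps, push_outputs, serve are not the actual Petals API):

    import asyncio
    from typing import Any, AsyncIterator, Dict, Sequence, Tuple

    async def push_outputs(outputs: Any, metadata: Dict) -> None:
        ...  # send `outputs` to the next server, using routing info from this step's metadata

    async def iterate_steps(steps: Sequence[Dict]) -> AsyncIterator[Tuple[Any, bool, Dict]]:
        for step_metadata in steps:
            output_tensors = ...  # run the local blocks on this step's inputs
            can_push = True       # pushing is only possible when there are no client-side prompts
            yield output_tensors, can_push, step_metadata  # the step's metadata travels with its outputs

    async def serve(steps: Sequence[Dict]) -> None:
        background_tasks = set()
        async for output_tensors, can_push, step_metadata in iterate_steps(steps):
            if can_push:
                task = asyncio.create_task(push_outputs(output_tensors, step_metadata))
                background_tasks.add(task)  # keep a reference so the task is not garbage-collected
                task.add_done_callback(background_tasks.discard)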

---------

Co-authored-by: Yingtong Dou <ytongdou@gmail.com>
Committed by justheuristic (1 month ago, via GitHub) · commit 2ad0b2b936 · parent efee5d1fa8

@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.38.2 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.38.2,<4.39.0"


 def _override_bfloat16_mode_default():

@@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
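For context, the new default mirrors what transformers 4.38 expects: when position_ids are not given, they are derived from how many tokens already sit in the KV cache. A standalone toy example of that arithmetic (tensor shapes are made up for illustration; cached keys are assumed to be [batch, heads, past_seq, head_dim]):

    import torch

    past_key = torch.zeros(1, 4, 6, 64)     # hypothetical cached keys: 6 tokens already processed
    hidden_states = torch.zeros(1, 3, 256)  # 3 new tokens arrive in this inference step

    past_seen_tokens = past_key.shape[2]
    position_ids = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1]).unsqueeze(0)
    print(position_ids)  # tensor([[6, 7, 8]]) -- absolute positions of the new tokens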
@@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
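The rotary-embedding call changes because, as of transformers 4.38, LlamaRotaryEmbedding takes position_ids and returns cos/sin already gathered per position (roughly [batch, seq, head_dim]), so the block adds a head dimension with unsqueeze(1) instead of slicing the last q_len positions. Below is a version-agnostic sketch of the same rotation written directly in torch, so it does not rely on any particular transformers signature; head_dim and base are illustrative defaults, not values read from a model config:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q, k, position_ids, head_dim=64, base=10000.0):
        # Per-channel-pair frequencies, as in the Llama rotary embedding
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [batch, seq, head_dim/2]
        emb = torch.cat((angles, angles), dim=-1)                            # [batch, seq, head_dim]
        cos, sin = emb.cos().unsqueeze(1), emb.sin().unsqueeze(1)            # add a head dimension
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

    q = torch.randn(1, 4, 3, 64)              # [batch, heads, seq, head_dim]
    k = torch.randn(1, 4, 3, 64)
    position_ids = torch.tensor([[6, 7, 8]])  # absolute positions of the 3 new tokens
    q_rot, k_rot = apply_rope(q, k, position_ids)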
@@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = residual + hidden_states

@@ -47,6 +47,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)

     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
             for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
         can_push = not has_prompts
-        yield output_tensors, can_push
+        yield output_tensors, can_push, step_metadata

         # prepare for next step
         prefix_length += length_increment

@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
                 ) as cache_handles:
                     background_tasks = set()
-                    async for output_tensors, can_push in iterate_rpc_inference(
+                    async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
                         active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         args_structure=args_structure,
                     ):
                         if can_push:
-                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                             background_tasks.add(task) # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
