measurements
Your Name 11 months ago
parent c735dd7ba3
commit 7dc1aa5151

@@ -1,5 +1,6 @@
from __future__ import annotations
import time
from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -188,8 +189,11 @@ class _MergedInferenceStep:
assert len(inference_infos) == len(
optional_prompts
), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
t0 = time.perf_counter()
for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
if optional_prompt is not None:
hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
(hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
torch.cuda.synchronize()
print(f"INFERENCE TIME: {time.perf_counter() - t0:.5f} s")
return (hidden_states,)

Loading…
Cancel
Save