@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import time
 from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -188,8 +189,11 @@ class _MergedInferenceStep:
         assert len(inference_infos) == len(
             optional_prompts
         ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        t0 = time.perf_counter()
         for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        torch.cuda.synchronize()
+        print(f"INFERENCE TIME: {time.perf_counter() - t0:.5f} s")
         return (hidden_states,)
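The added lines follow the usual pattern for wall-clock timing of GPU work: CUDA kernels launch asynchronously, so without calling `torch.cuda.synchronize()` before reading the clock, `time.perf_counter()` would measure only kernel launch overhead rather than actual execution time. A minimal self-contained sketch of the same pattern (the `timed_inference` helper is hypothetical, for illustration only, and not part of this patch):

```python
import time

import torch


def timed_inference(fn, *args):
    """Measure wall-clock time of a (possibly CUDA) workload.

    CUDA kernels run asynchronously, so the GPU must be drained both
    before starting the clock and before stopping it; otherwise the
    measurement reflects kernel launch time, not execution time.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain previously queued GPU work
    t0 = time.perf_counter()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for this call's kernels to finish
    return out, time.perf_counter() - t0
```

Note that in the patch the synchronize-and-print pair sits outside the per-block loop, so it reports one end-to-end latency per merged inference step at the cost of a single pipeline stall; synchronizing inside the loop would instead time each block individually but would add a stall per block.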