2023-12-11 21:53:30 +00:00
|
|
|
from typing import Any, Iterator, List, Mapping, Optional
|
|
|
|
|
|
|
|
import requests
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
|
|
from langchain_core.language_models.llms import LLM
|
|
|
|
from langchain_core.outputs import GenerationChunk
|
|
|
|
from requests.exceptions import ConnectionError
|
|
|
|
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
|
|
|
|
|
|
class TitanTakeoff(LLM):
|
|
|
|
"""Wrapper around Titan Takeoff APIs."""
|
|
|
|
|
|
|
|
base_url: str = "http://localhost:8000"
|
|
|
|
"""Specifies the baseURL to use for the Titan Takeoff API.
|
|
|
|
Default = http://localhost:8000.
|
|
|
|
"""
|
|
|
|
|
|
|
|
generate_max_length: int = 128
|
|
|
|
"""Maximum generation length. Default = 128."""
|
|
|
|
|
|
|
|
sampling_topk: int = 1
|
|
|
|
"""Sample predictions from the top K most probable candidates. Default = 1."""
|
|
|
|
|
|
|
|
sampling_topp: float = 1.0
|
|
|
|
"""Sample from predictions whose cumulative probability exceeds this value.
|
|
|
|
Default = 1.0.
|
|
|
|
"""
|
|
|
|
|
|
|
|
sampling_temperature: float = 1.0
|
|
|
|
"""Sample with randomness. Bigger temperatures are associated with
|
|
|
|
more randomness and 'creativity'. Default = 1.0.
|
|
|
|
"""
|
|
|
|
|
|
|
|
repetition_penalty: float = 1.0
|
|
|
|
"""Penalise the generation of tokens that have been generated before.
|
|
|
|
Set to > 1 to penalize. Default = 1 (no penalty).
|
|
|
|
"""
|
|
|
|
|
|
|
|
no_repeat_ngram_size: int = 0
|
|
|
|
"""Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""
|
|
|
|
|
|
|
|
streaming: bool = False
|
|
|
|
"""Whether to stream the output. Default = False."""
|
|
|
|
|
|
|
|
@property
|
|
|
|
def _default_params(self) -> Mapping[str, Any]:
|
|
|
|
"""Get the default parameters for calling Titan Takeoff Server."""
|
|
|
|
params = {
|
|
|
|
"generate_max_length": self.generate_max_length,
|
|
|
|
"sampling_topk": self.sampling_topk,
|
|
|
|
"sampling_topp": self.sampling_topp,
|
|
|
|
"sampling_temperature": self.sampling_temperature,
|
|
|
|
"repetition_penalty": self.repetition_penalty,
|
|
|
|
"no_repeat_ngram_size": self.no_repeat_ngram_size,
|
|
|
|
}
|
|
|
|
return params
|
|
|
|
|
|
|
|
@property
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
"""Return type of llm."""
|
|
|
|
return "titan_takeoff"
|
|
|
|
|
|
|
|
def _call(
|
|
|
|
self,
|
|
|
|
prompt: str,
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> str:
|
|
|
|
"""Call out to Titan Takeoff generate endpoint.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
prompt: The prompt to pass into the model.
|
|
|
|
stop: Optional list of stop words to use when generating.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
The string generated by the model.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
prompt = "What is the capital of the United Kingdom?"
|
|
|
|
response = model(prompt)
|
|
|
|
|
|
|
|
"""
|
|
|
|
try:
|
|
|
|
if self.streaming:
|
|
|
|
text_output = ""
|
|
|
|
for chunk in self._stream(
|
|
|
|
prompt=prompt,
|
|
|
|
stop=stop,
|
|
|
|
run_manager=run_manager,
|
|
|
|
):
|
|
|
|
text_output += chunk.text
|
|
|
|
return text_output
|
|
|
|
|
|
|
|
url = f"{self.base_url}/generate"
|
|
|
|
params = {"text": prompt, **self._default_params}
|
|
|
|
|
|
|
|
response = requests.post(url, json=params)
|
|
|
|
response.raise_for_status()
|
|
|
|
response.encoding = "utf-8"
|
|
|
|
text = ""
|
|
|
|
|
|
|
|
if "message" in response.json():
|
|
|
|
text = response.json()["message"]
|
|
|
|
else:
|
|
|
|
raise ValueError("Something went wrong.")
|
|
|
|
if stop is not None:
|
|
|
|
text = enforce_stop_tokens(text, stop)
|
|
|
|
return text
|
|
|
|
except ConnectionError:
|
|
|
|
raise ConnectionError(
|
|
|
|
"Could not connect to Titan Takeoff server. \
|
|
|
|
Please make sure that the server is running."
|
|
|
|
)
|
|
|
|
|
|
|
|
def _stream(
|
|
|
|
self,
|
|
|
|
prompt: str,
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> Iterator[GenerationChunk]:
|
|
|
|
"""Call out to Titan Takeoff stream endpoint.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
prompt: The prompt to pass into the model.
|
|
|
|
stop: Optional list of stop words to use when generating.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
The string generated by the model.
|
|
|
|
|
|
|
|
Yields:
|
|
|
|
A dictionary like object containing a string token.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
prompt = "What is the capital of the United Kingdom?"
|
|
|
|
response = model(prompt)
|
|
|
|
|
|
|
|
"""
|
|
|
|
url = f"{self.base_url}/generate_stream"
|
|
|
|
params = {"text": prompt, **self._default_params}
|
|
|
|
|
|
|
|
response = requests.post(url, json=params, stream=True)
|
|
|
|
response.encoding = "utf-8"
|
|
|
|
for text in response.iter_content(chunk_size=1, decode_unicode=True):
|
|
|
|
if text:
|
|
|
|
chunk = GenerationChunk(text=text)
|
|
|
|
if run_manager:
|
|
|
|
run_manager.on_llm_new_token(token=chunk.text)
|
2024-03-05 20:54:26 +00:00
|
|
|
yield chunk
|
2023-12-11 21:53:30 +00:00
|
|
|
|
|
|
|
@property
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
|
"""Get the identifying parameters."""
|
|
|
|
return {"base_url": self.base_url, **{}, **self._default_params}
|