mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
162 lines
5.1 KiB
Python
162 lines
5.1 KiB
Python
|
from typing import Any, Iterator, List, Mapping, Optional
|
||
|
|
||
|
import requests
|
||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||
|
from langchain_core.language_models.llms import LLM
|
||
|
from langchain_core.outputs import GenerationChunk
|
||
|
from requests.exceptions import ConnectionError
|
||
|
|
||
|
from langchain_community.llms.utils import enforce_stop_tokens
|
||
|
|
||
|
|
||
|
class TitanTakeoff(LLM):
    """Wrapper around Titan Takeoff APIs.

    Prompts are POSTed as JSON (``{"text": prompt, **sampling params}``) to
    ``{base_url}/generate``, or to ``{base_url}/generate_stream`` when
    ``streaming`` is enabled.
    """

    base_url: str = "http://localhost:8000"
    """Specifies the baseURL to use for the Titan Takeoff API.
    Default = http://localhost:8000.
    """

    generate_max_length: int = 128
    """Maximum generation length. Default = 128."""

    sampling_topk: int = 1
    """Sample predictions from the top K most probable candidates. Default = 1."""

    sampling_topp: float = 1.0
    """Sample from predictions whose cumulative probability exceeds this value.
    Default = 1.0.
    """

    sampling_temperature: float = 1.0
    """Sample with randomness. Bigger temperatures are associated with
    more randomness and 'creativity'. Default = 1.0.
    """

    repetition_penalty: float = 1.0
    """Penalise the generation of tokens that have been generated before.
    Set to > 1 to penalize. Default = 1 (no penalty).
    """

    no_repeat_ngram_size: int = 0
    """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""

    streaming: bool = False
    """Whether to stream the output. Default = False."""

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Titan Takeoff Server."""
        return {
            "generate_max_length": self.generate_max_length,
            "sampling_topk": self.sampling_topk,
            "sampling_topp": self.sampling_topp,
            "sampling_temperature": self.sampling_temperature,
            "repetition_penalty": self.repetition_penalty,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "titan_takeoff"

    def _post_to_takeoff(
        self, endpoint: str, prompt: str, stream: bool = False
    ) -> requests.Response:
        """POST ``prompt`` plus the default sampling params to ``endpoint``.

        Args:
            endpoint: Path under ``base_url`` (e.g. ``"generate"``).
            prompt: The prompt text, sent under the ``"text"`` key.
            stream: Forwarded to ``requests.post`` to stream the response body.

        Returns:
            The raw ``requests.Response``; the caller checks its status.

        Raises:
            ConnectionError: If the server cannot be reached. The original
                ``requests`` error is chained as the cause.
        """
        url = f"{self.base_url}/{endpoint}"
        params = {"text": prompt, **self._default_params}
        try:
            return requests.post(url, json=params, stream=stream)
        except ConnectionError as err:
            # Implicit concatenation keeps the message free of the source
            # indentation that a backslash line continuation would embed.
            raise ConnectionError(
                "Could not connect to Titan Takeoff server. "
                "Please make sure that the server is running."
            ) from err

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Titan Takeoff generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating. They are
                applied to the final text in both streaming and
                non-streaming mode.
            run_manager: Optional callback manager. In streaming mode it
                receives each new token.

        Returns:
            The string generated by the model.

        Raises:
            ConnectionError: If the Titan Takeoff server cannot be reached.
            requests.HTTPError: If the server responds with an error status.
            ValueError: If the non-streaming response has no ``"message"``.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model.invoke(prompt)

        """
        if self.streaming:
            text = "".join(
                chunk.text
                for chunk in self._stream(
                    prompt=prompt,
                    stop=stop,
                    run_manager=run_manager,
                    **kwargs,
                )
            )
        else:
            response = self._post_to_takeoff("generate", prompt)
            response.raise_for_status()
            response.encoding = "utf-8"
            payload = response.json()  # parse the body once
            if not isinstance(payload, dict) or "message" not in payload:
                raise ValueError("Something went wrong.")
            text = payload["message"]

        # Apply stop words on both paths so `stop` behaves the same
        # regardless of the `streaming` flag.
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Titan Takeoff stream endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Accepted for interface compatibility. Stop words are not
                applied to individual streamed chunks; ``_call`` applies them
                to the joined text.
            run_manager: Optional callback manager notified of every token.

        Yields:
            A ``GenerationChunk`` for each non-empty piece of decoded text
            received from the server.

        Raises:
            ConnectionError: If the Titan Takeoff server cannot be reached.
            requests.HTTPError: If the server responds with an error status.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                for chunk in model.stream(prompt):
                    print(chunk, end="")

        """
        response = self._post_to_takeoff("generate_stream", prompt, stream=True)
        # `with` releases the streamed connection even if the consumer stops
        # iterating early or an exception is raised mid-stream.
        with response:
            response.raise_for_status()
            response.encoding = "utf-8"
            for text in response.iter_content(chunk_size=1, decode_unicode=True):
                if not text:
                    continue
                chunk = GenerationChunk(text=text)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(token=chunk.text)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"base_url": self.base_url, **self._default_params}
|