mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
7d216ad1e1
## PR title community[patch]: Invoke callback prior to yielding token ## PR message - Description: Invoke callback prior to yielding token in _stream_ method in llms/titan_takeoff_pro. - Issue: #16913 - Dependencies: None
218 lines
7.2 KiB
Python
218 lines
7.2 KiB
Python
from typing import Any, Iterator, List, Mapping, Optional
|
|
|
|
import requests
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.outputs import GenerationChunk
|
|
from requests.exceptions import ConnectionError
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
class TitanTakeoffPro(LLM):
    """Titan Takeoff Pro is a language model that can be used to generate text.

    Requests are sent with HTTP POST to ``{base_url}/generate`` (blocking) or
    ``{base_url}/generate_stream`` (token streaming). Only generation
    parameters that are explicitly set (not ``None``) are forwarded to the
    server.
    """

    base_url: Optional[str] = "http://localhost:3000"
    """Specifies the baseURL to use for the Titan Takeoff Pro API.
    Default = http://localhost:3000.
    """

    max_new_tokens: Optional[int] = None
    """Maximum tokens generated."""

    min_new_tokens: Optional[int] = None
    """Minimum tokens generated."""

    sampling_topk: Optional[int] = None
    """Sample predictions from the top K most probable candidates."""

    sampling_topp: Optional[float] = None
    """Sample from predictions whose cumulative probability exceeds this value.
    """

    sampling_temperature: Optional[float] = None
    """Sample with randomness. Bigger temperatures are associated with
    more randomness and 'creativity'.
    """

    repetition_penalty: Optional[float] = None
    """Penalise the generation of tokens that have been generated before.
    Set to > 1 to penalize.
    """

    regex_string: Optional[str] = None
    """A regex string for constrained generation."""

    no_repeat_ngram_size: Optional[int] = None
    """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""

    streaming: bool = False
    """Whether to stream the output. Default = False."""

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Titan Takeoff Server (Pro).

        Returns:
            A mapping holding only the generation parameters whose value is
            not ``None``; unset parameters are omitted from the request.
        """
        candidates = {
            "regex_string": self.regex_string,
            "sampling_temperature": self.sampling_temperature,
            "sampling_topp": self.sampling_topp,
            "repetition_penalty": self.repetition_penalty,
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "sampling_topk": self.sampling_topk,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
        }
        return {
            name: value for name, value in candidates.items() if value is not None
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "titan_takeoff_pro"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Titan Takeoff (Pro) generate endpoint.

        When ``self.streaming`` is true the text is assembled from the chunks
        produced by ``_stream``; otherwise a single POST is made to
        ``{base_url}/generate``.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager, forwarded to ``_stream``
                when streaming.

        Returns:
            The string generated by the model.

        Raises:
            ConnectionError: If the server cannot be reached.
            ValueError: If the response JSON has no ``"text"`` key.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model(prompt)

        """
        try:
            if self.streaming:
                return "".join(
                    chunk.text
                    for chunk in self._stream(
                        prompt=prompt,
                        stop=stop,
                        run_manager=run_manager,
                    )
                )

            url = f"{self.base_url}/generate"
            params = {"text": prompt, **self._default_params}

            response = requests.post(url, json=params)
            response.raise_for_status()
            response.encoding = "utf-8"

            # Decode the body once instead of re-parsing it for every access.
            payload = response.json()
            if "text" not in payload:
                raise ValueError("Something went wrong.")

            # Drop the end-of-sequence marker from the generated text.
            text = payload["text"].replace("</s>", "")
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text
        except ConnectionError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ConnectionError(
                "Could not connect to Titan Takeoff (Pro) server. "
                "Please make sure that the server is running."
            ) from err

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Titan Takeoff (Pro) stream endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words. NOTE: accepted for interface
                compatibility but not applied to the streamed output here.
            run_manager: Optional callback manager; ``on_llm_new_token`` is
                invoked for each chunk before that chunk is yielded.

        Yields:
            A ``GenerationChunk`` for each piece of text found between
            ``data:`` markers in the response stream.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model(prompt)

        """
        url = f"{self.base_url}/generate_stream"
        params = {"text": prompt, **self._default_params}

        response = requests.post(url, json=params, stream=True)
        # Fail on HTTP errors (as `_call` does) rather than streaming an
        # error body back as if it were generated tokens.
        response.raise_for_status()
        response.encoding = "utf-8"
        buffer = ""
        for text in response.iter_content(chunk_size=1, decode_unicode=True):
            buffer += text
            if "data:" in buffer:
                # A buffer that begins with the marker carries no payload yet:
                # discard it and start accumulating the payload that follows.
                if buffer.startswith("data:"):
                    buffer = ""
                # Otherwise keep only the text that precedes the marker,
                # without its trailing newlines.
                if len(buffer.split("data:", 1)) == 2:
                    content, _ = buffer.split("data:", 1)
                    buffer = content.rstrip("\n")
                if buffer:  # Ensure that there's content to process.
                    chunk = GenerationChunk(text=buffer)
                    buffer = ""  # Reset buffer for the next set of data.
                    # Notify callbacks before yielding so handlers see the
                    # token even if the consumer stops iterating afterwards.
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)
                    yield chunk

        # Yield any remaining content in the buffer.
        if buffer:
            chunk = GenerationChunk(text=buffer.replace("</s>", ""))
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)
            yield chunk

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"base_url": self.base_url, **self._default_params}
|