from typing import Any, Iterator, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from requests.exceptions import ConnectionError

from langchain_community.llms.utils import enforce_stop_tokens


class TitanTakeoffPro(LLM):
    """LLM wrapper that sends prompts to a Titan Takeoff Pro server over HTTP.

    Non-streaming calls POST to ``{base_url}/generate``; streaming calls POST
    to ``{base_url}/generate_stream``. Every generation option left as
    ``None`` is omitted from the request body.
    """

    base_url: Optional[str] = "http://localhost:3000"
    """Specifies the baseURL to use for the Titan Takeoff Pro API.
    Default = http://localhost:3000.
    """

    max_new_tokens: Optional[int] = None
    """Maximum tokens generated."""

    min_new_tokens: Optional[int] = None
    """Minimum tokens generated."""

    sampling_topk: Optional[int] = None
    """Sample predictions from the top K most probable candidates."""

    sampling_topp: Optional[float] = None
    """Sample from predictions whose cumulative probability exceeds this value.
    """

    sampling_temperature: Optional[float] = None
    """Sample with randomness. Bigger temperatures are associated with
    more randomness and 'creativity'.
    """

    repetition_penalty: Optional[float] = None
    """Penalise the generation of tokens that have been generated before.
    Set to > 1 to penalize.
    """

    regex_string: Optional[str] = None
    """A regex string for constrained generation."""

    no_repeat_ngram_size: Optional[int] = None
    """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""

    streaming: bool = False
    """Whether to stream the output. Default = False."""

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Titan Takeoff Server (Pro).

        Returns:
            A mapping of request option name to value, containing only the
            options that are not ``None`` on this instance.
        """
        # Kept local (not a class attribute) so it is not picked up as a
        # model field. The order matches the original request body layout.
        option_names = (
            "regex_string",
            "sampling_temperature",
            "sampling_topp",
            "repetition_penalty",
            "max_new_tokens",
            "min_new_tokens",
            "sampling_topk",
            "no_repeat_ngram_size",
        )
        configured = ((name, getattr(self, name)) for name in option_names)
        return {name: value for name, value in configured if value is not None}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "titan_takeoff_pro"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Titan Takeoff (Pro) generate endpoint.

        When ``streaming`` is enabled the text is assembled from ``_stream``;
        otherwise a single POST is made to ``{base_url}/generate``.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating. When
                given, the final text is passed through
                ``enforce_stop_tokens`` in both streaming and non-streaming
                mode.
            run_manager: Optional callback manager; only used in streaming
                mode, where it is forwarded to ``_stream``.
            **kwargs: Extra keyword arguments; forwarded to ``_stream`` in
                streaming mode and otherwise unused.

        Returns:
            The string generated by the model.

        Raises:
            ConnectionError: If the server cannot be reached.
            ValueError: If the JSON response has no ``"text"`` entry.
            requests.exceptions.HTTPError: If the server answers with an
                error status code.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model(prompt)

        """
        try:
            if self.streaming:
                # The generator is fully consumed inside the ``try`` so that
                # connection failures raised lazily are translated below.
                text = "".join(
                    chunk.text
                    for chunk in self._stream(
                        prompt=prompt,
                        stop=stop,
                        run_manager=run_manager,
                        **kwargs,
                    )
                )
            else:
                url = f"{self.base_url}/generate"
                params = {"text": prompt, **self._default_params}

                response = requests.post(url, json=params)
                response.raise_for_status()
                response.encoding = "utf-8"

                # Parse the body once instead of once per access.
                payload = response.json()
                if "text" not in payload:
                    raise ValueError("Something went wrong.")
                # NOTE(review): an earlier revision called
                # ``text.replace("", "")`` here, which is a no-op. Presumably
                # some marker was meant to be stripped from the output --
                # confirm against the server's response format before
                # re-adding anything.
                text = payload["text"]
        except ConnectionError as err:
            # Adjacent literals rather than a backslash inside the string,
            # which leaked stray characters into the message.
            raise ConnectionError(
                "Could not connect to Titan Takeoff (Pro) server. "
                "Please make sure that the server is running."
            ) from err

        # Applied on both paths: ``_stream`` itself does not honour ``stop``,
        # so streaming mode previously ignored it entirely. Tokens already
        # reported to ``run_manager`` during streaming are not retracted.
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Titan Takeoff (Pro) stream endpoint.

        The response body is read incrementally and split on the literal
        ``"data:"`` marker; the text accumulated before each marker (with
        trailing newlines removed) is emitted as one chunk.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words. Not applied here; ``_call``
                applies it to the assembled text.
            run_manager: Optional callback manager; ``on_llm_new_token`` is
                invoked for every chunk after it has been yielded.
            **kwargs: Accepted for interface compatibility; unused.

        Yields:
            ``GenerationChunk`` objects, each holding one piece of generated
            text.

        Raises:
            requests.exceptions.HTTPError: If the server answers with an
                error status code.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model(prompt)

        """
        url = f"{self.base_url}/generate_stream"
        params = {"text": prompt, **self._default_params}

        # ``with`` releases the streamed connection even if the consumer
        # stops iterating early or an exception is raised mid-stream.
        with requests.post(url, json=params, stream=True) as response:
            # Without this an HTTP error body would be parsed and emitted as
            # if it were generated text.
            response.raise_for_status()
            response.encoding = "utf-8"

            buffer = ""
            # chunk_size=1 is deliberate: the parser below discards whatever
            # follows the first marker in the buffer, so the marker must be
            # detected as soon as it completes.
            for text in response.iter_content(chunk_size=1, decode_unicode=True):
                buffer += text
                if "data:" not in buffer:
                    continue

                # Everything before the marker is the pending token; the
                # marker itself is dropped and the buffer starts afresh.
                content, _, _ = buffer.partition("data:")
                buffer = ""
                token = content.rstrip("\n")
                if token:  # Skip the empty prefix before the first marker.
                    chunk = GenerationChunk(text=token)
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)

            # Yield any remaining content in the buffer.
            # NOTE(review): an earlier revision applied a no-op
            # ``.replace("", "")`` to this final piece; presumably a trailing
            # marker was meant to be removed -- confirm before re-adding.
            if buffer:
                chunk = GenerationChunk(text=buffer)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(token=chunk.text)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        Returns:
            A mapping with ``base_url`` plus every configured (non-``None``)
            generation option.
        """
        return {"base_url": self.base_url, **self._default_params}