mirror of https://github.com/hwchase17/langchain
Takeoff pro support (#12070)
**Description:** This PR adds support for the [Pro version of Titan Takeoff Server](https://docs.titanml.co/docs/category/pro-features). Users of the Pro version will have to import the TitanTakeoffPro model, which is different from TitanTakeoff. **Issue:** Also minor fixes to docs for Titan Takeoff (Community version) **Dependencies:** No additional dependencies **Twitter handle:** @becoming_blake @baskaryan @hwchase17pull/12352/head
parent
4e47fe1dce
commit
b9410f2b6f
@ -0,0 +1,100 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Titan Takeoff Pro\n",
|
||||||
|
"\n",
|
||||||
|
"`TitanML` helps businesses build and deploy better, smaller, cheaper, and faster NLP models through our training, compression, and inference optimization platform.\n",
|
||||||
|
"\n",
|
||||||
|
">Note: These docs are for the Pro version of Titan Takeoff. For the community version, see the page for Titan Takeoff.\n",
|
||||||
|
"\n",
|
||||||
|
"Our inference server, [Titan Takeoff (Pro Version)](https://docs.titanml.co/docs/titan-takeoff/pro-features/feature-comparison) enables deployment of LLMs locally on your hardware in a single command. Most generative model architectures are supported, such as Falcon, Llama 2, GPT2, T5 and many more."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Example usage\n",
|
||||||
|
"Here are some helpful examples to get started using the Pro version of Titan Takeoff Server.\n",
|
||||||
|
"No parameters are needed by default, but a baseURL that points to your desired URL where Takeoff is running can be specified and generation parameters can be supplied."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import TitanTakeoffPro\n",
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||||
|
"from langchain.callbacks.manager import CallbackManager\n",
|
||||||
|
"\n",
|
||||||
|
"# Example 1: Basic use\n",
|
||||||
|
"llm = TitanTakeoffPro()\n",
|
||||||
|
"output = llm(\"What is the weather in London in August?\")\n",
|
||||||
|
"print(output)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Example 2: Specifying a port and other generation parameters\n",
|
||||||
|
"llm = TitanTakeoffPro(\n",
|
||||||
|
" base_url=\"http://localhost:3000\",\n",
|
||||||
|
" min_new_tokens=128,\n",
|
||||||
|
" max_new_tokens=512,\n",
|
||||||
|
" no_repeat_ngram_size=2,\n",
|
||||||
|
" sampling_topk= 1,\n",
|
||||||
|
" sampling_topp= 1.0,\n",
|
||||||
|
" sampling_temperature= 1.0,\n",
|
||||||
|
" repetition_penalty= 1.0,\n",
|
||||||
|
" regex_string= \"\",\n",
|
||||||
|
")\n",
|
||||||
|
"output = llm(\"What is the largest rainforest in the world?\")\n",
|
||||||
|
"print(output)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Example 3: Using generate for multiple inputs\n",
|
||||||
|
"llm = TitanTakeoffPro()\n",
|
||||||
|
"rich_output = llm.generate([\"What is Deep Learning?\", \"What is Machine Learning?\"])\n",
|
||||||
|
"print(rich_output.generations)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Example 4: Streaming output\n",
|
||||||
|
"llm = TitanTakeoffPro(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))\n",
|
||||||
|
"prompt = \"What is the capital of France?\"\n",
|
||||||
|
"llm(prompt)\n",
|
||||||
|
"\n",
|
||||||
|
"# Example 5: Using LCEL\n",
|
||||||
|
"llm = TitanTakeoffPro()\n",
|
||||||
|
"prompt = PromptTemplate.from_template(\"Tell me about {topic}\")\n",
|
||||||
|
"chain = prompt | llm\n",
|
||||||
|
"chain.invoke({\"topic\": \"the universe\"})"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -0,0 +1,215 @@
|
|||||||
|
from typing import Any, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import ConnectionError
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from langchain.schema.output import GenerationChunk
|
||||||
|
|
||||||
|
|
||||||
|
class TitanTakeoffPro(LLM):
|
||||||
|
base_url: Optional[str] = "http://localhost:3000"
|
||||||
|
"""Specifies the baseURL to use for the Titan Takeoff Pro API.
|
||||||
|
Default = http://localhost:3000.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_new_tokens: Optional[int] = None
|
||||||
|
"""Maximum tokens generated."""
|
||||||
|
|
||||||
|
min_new_tokens: Optional[int] = None
|
||||||
|
"""Minimum tokens generated."""
|
||||||
|
|
||||||
|
sampling_topk: Optional[int] = None
|
||||||
|
"""Sample predictions from the top K most probable candidates."""
|
||||||
|
|
||||||
|
sampling_topp: Optional[float] = None
|
||||||
|
"""Sample from predictions whose cumulative probability exceeds this value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sampling_temperature: Optional[float] = None
|
||||||
|
"""Sample with randomness. Bigger temperatures are associated with
|
||||||
|
more randomness and 'creativity'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
"""Penalise the generation of tokens that have been generated before.
|
||||||
|
Set to > 1 to penalize.
|
||||||
|
"""
|
||||||
|
|
||||||
|
regex_string: Optional[str] = None
|
||||||
|
"""A regex string for constrained generation."""
|
||||||
|
|
||||||
|
no_repeat_ngram_size: Optional[int] = None
|
||||||
|
"""Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""
|
||||||
|
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the output. Default = False."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the default parameters for calling Titan Takeoff Server (Pro)."""
|
||||||
|
return {
|
||||||
|
**(
|
||||||
|
{"regex_string": self.regex_string}
|
||||||
|
if self.regex_string is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"sampling_temperature": self.sampling_temperature}
|
||||||
|
if self.sampling_temperature is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"sampling_topp": self.sampling_topp}
|
||||||
|
if self.sampling_topp is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"repetition_penalty": self.repetition_penalty}
|
||||||
|
if self.repetition_penalty is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"max_new_tokens": self.max_new_tokens}
|
||||||
|
if self.max_new_tokens is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"min_new_tokens": self.min_new_tokens}
|
||||||
|
if self.min_new_tokens is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"sampling_topk": self.sampling_topk}
|
||||||
|
if self.sampling_topk is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"no_repeat_ngram_size": self.no_repeat_ngram_size}
|
||||||
|
if self.no_repeat_ngram_size is not None
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "titan_takeoff_pro"
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Call out to Titan Takeoff (Pro) generate endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string generated by the model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompt = "What is the capital of the United Kingdom?"
|
||||||
|
response = model(prompt)
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.streaming:
|
||||||
|
text_output = ""
|
||||||
|
for chunk in self._stream(
|
||||||
|
prompt=prompt,
|
||||||
|
stop=stop,
|
||||||
|
run_manager=run_manager,
|
||||||
|
):
|
||||||
|
text_output += chunk.text
|
||||||
|
return text_output
|
||||||
|
url = f"{self.base_url}/generate"
|
||||||
|
params = {"text": prompt, **self._default_params}
|
||||||
|
|
||||||
|
response = requests.post(url, json=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
response.encoding = "utf-8"
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
if "text" in response.json():
|
||||||
|
text = response.json()["text"]
|
||||||
|
text = text.replace("</s>", "")
|
||||||
|
else:
|
||||||
|
raise ValueError("Something went wrong.")
|
||||||
|
if stop is not None:
|
||||||
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
return text
|
||||||
|
except ConnectionError:
|
||||||
|
raise ConnectionError(
|
||||||
|
"Could not connect to Titan Takeoff (Pro) server. \
|
||||||
|
Please make sure that the server is running."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GenerationChunk]:
|
||||||
|
"""Call out to Titan Takeoff (Pro) stream endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string generated by the model.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A dictionary like object containing a string token.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompt = "What is the capital of the United Kingdom?"
|
||||||
|
response = model(prompt)
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/generate_stream"
|
||||||
|
params = {"text": prompt, **self._default_params}
|
||||||
|
|
||||||
|
response = requests.post(url, json=params, stream=True)
|
||||||
|
response.encoding = "utf-8"
|
||||||
|
buffer = ""
|
||||||
|
for text in response.iter_content(chunk_size=1, decode_unicode=True):
|
||||||
|
buffer += text
|
||||||
|
if "data:" in buffer:
|
||||||
|
# Remove the first instance of "data:" from the buffer.
|
||||||
|
if buffer.startswith("data:"):
|
||||||
|
buffer = ""
|
||||||
|
if len(buffer.split("data:", 1)) == 2:
|
||||||
|
content, _ = buffer.split("data:", 1)
|
||||||
|
buffer = content.rstrip("\n")
|
||||||
|
# Trim the buffer to only have content after the "data:" part.
|
||||||
|
if buffer: # Ensure that there's content to process.
|
||||||
|
chunk = GenerationChunk(text=buffer)
|
||||||
|
buffer = "" # Reset buffer for the next set of data.
|
||||||
|
yield chunk
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(token=chunk.text)
|
||||||
|
|
||||||
|
# Yield any remaining content in the buffer.
|
||||||
|
if buffer:
|
||||||
|
chunk = GenerationChunk(text=buffer.replace("</s>", ""))
|
||||||
|
yield chunk
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(token=chunk.text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {"base_url": self.base_url, **{}, **self._default_params}
|
@ -0,0 +1,18 @@
|
|||||||
|
"""Test Titan Takeoff wrapper."""
|
||||||
|
|
||||||
|
|
||||||
|
import responses
|
||||||
|
|
||||||
|
from langchain.llms.titan_takeoff_pro import TitanTakeoffPro
|
||||||
|
|
||||||
|
|
||||||
|
@responses.activate
|
||||||
|
def test_titan_takeoff_pro_call() -> None:
|
||||||
|
"""Test valid call to Titan Takeoff."""
|
||||||
|
url = "http://localhost:3000/generate"
|
||||||
|
responses.add(responses.POST, url, json={"message": "2 + 2 is 4"}, status=200)
|
||||||
|
|
||||||
|
# response = requests.post(url)
|
||||||
|
llm = TitanTakeoffPro()
|
||||||
|
output = llm("What is 2 + 2?")
|
||||||
|
assert isinstance(output, str)
|
Loading…
Reference in New Issue