diff --git a/libs/langchain/langchain/llms/textgen.py b/libs/langchain/langchain/llms/textgen.py
index 0d846ecc88..5f83dc08b9 100644
--- a/libs/langchain/langchain/llms/textgen.py
+++ b/libs/langchain/langchain/llms/textgen.py
@@ -1,10 +1,13 @@
 import json
 import logging
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
 
 import requests
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Field
 from langchain.schema.output import GenerationChunk
@@ -224,6 +227,54 @@ class TextGen(LLM):
 
         return result
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the textgen web API asynchronously and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(model_url="http://localhost:5000")
+                await llm.apredict("Write a story about llamas.")
+        """
+        if self.streaming:
+            combined_text_output = ""
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            print(prompt + combined_text_output)
+            result = combined_text_output
+
+        else:
+            url = f"{self.model_url}/api/v1/generate"
+            params = self._get_parameters(stop)
+            request = params.copy()
+            request["prompt"] = prompt
+            response = requests.post(url, json=request)
+
+            if response.status_code == 200:
+                result = response.json()["results"][0]["text"]
+                print(prompt + result)
+            else:
+                print(f"ERROR: response: {response}")
+                result = ""
+
+        return result
+
     def _stream(
         self,
         prompt: str,
@@ -296,3 +347,76 @@ class TextGen(LLM):
 
             if run_manager:
                 run_manager.on_llm_new_token(token=chunk.text)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        """Yield result objects as they are generated in real time.
+
+        It also fires the callback manager's on_llm_new_token event with
+        parameters similar to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            An async generator representing the stream of tokens generated.
+
+        Yields:
+            GenerationChunk objects containing a string token and optional
+            metadata. See the text-generation-webui docs for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(
+                    model_url="ws://localhost:5005",
+                    streaming=True,
+                )
+                async for chunk in llm.astream(
+                    "Ask 'Hi, how are you?' like a pirate:'", stop=["'", "\n"]):
+                    print(chunk, end='', flush=True)
+
+        """
+        try:
+            import websocket
+        except ImportError:
+            raise ImportError(
+                "The `websocket-client` package is required for streaming."
+            )
+
+        params = {**self._get_parameters(stop), **kwargs}
+
+        url = f"{self.model_url}/api/v1/stream"
+
+        request = params.copy()
+        request["prompt"] = prompt
+
+        websocket_client = websocket.WebSocket()
+
+        websocket_client.connect(url)
+
+        websocket_client.send(json.dumps(request))
+
+        while True:
+            result = websocket_client.recv()
+            result = json.loads(result)
+
+            if result["event"] == "text_stream":
+                chunk = GenerationChunk(
+                    text=result["text"],
+                    generation_info=None,
+                )
+                yield chunk
+                # Fire the callback only after a chunk has been produced.
+                if run_manager:
+                    await run_manager.on_llm_new_token(token=chunk.text)
+            elif result["event"] == "stream_end":
+                websocket_client.close()
+                return
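
A minimal usage sketch for the async paths added above, assuming a local
text-generation-webui server with the blocking API on port 5000 and the
streaming websocket API on port 5005 (the same endpoints used in the
docstrings). The public `apredict` and `astream` helpers on the LLM base
class route through the new `_acall` and `_astream` respectively:

.. code-block:: python

    import asyncio

    from langchain.llms import TextGen


    async def main() -> None:
        # Non-streaming: a single POST to /api/v1/generate via _acall.
        llm = TextGen(model_url="http://localhost:5000")
        print(await llm.apredict("Write a story about llamas."))

        # Streaming: tokens arrive over the /api/v1/stream websocket via
        # _astream; requires `pip install websocket-client`.
        streaming_llm = TextGen(model_url="ws://localhost:5005", streaming=True)
        async for chunk in streaming_llm.astream("Write a story about llamas."):
            print(chunk, end="", flush=True)


    asyncio.run(main())

Note that `_astream` drives the synchronous `websocket-client` library, so
each `recv()` blocks the event loop between frames; a fully non-blocking
variant would need an async websocket library such as `websockets`.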