TextGen is missing async methods. (#9986)

Adds the `_acall` and `_astream` methods that were missing. Their absence
was preventing streaming during async execution.

 @rlancemartin.
pull/10155/head
German Martin 1 year ago committed by GitHub
parent f4bed8a04c
commit cf5a50469f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,13 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field
from langchain.schema.output import GenerationChunk
@ -224,6 +227,54 @@ class TextGen(LLM):
return result
async def _acall(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Asynchronously call the textgen web API and return the output.

    When ``self.streaming`` is set, the text is assembled from the chunks
    produced by ``_astream``. Otherwise a single POST request is sent to
    ``{model_url}/api/v1/generate``.

    Args:
        prompt: The prompt to use for generation.
        stop: A list of strings to stop generation when encountered.
        run_manager: Optional async callback manager; it is only forwarded
            to ``_astream`` (streaming mode).
        **kwargs: Extra parameters; they are only forwarded to ``_astream``
            (streaming mode) and are not used for the non-streaming request.

    Returns:
        The generated text, or an empty string when the non-streaming
        request does not return HTTP 200.

    Example:
        .. code-block:: python

            from langchain.llms import TextGen
            llm = TextGen(model_url="http://localhost:5000")
            await llm.ainvoke("Write a story about llamas.")
    """
    if self.streaming:
        pieces: List[str] = []
        async for chunk in self._astream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            pieces.append(chunk.text)
        combined_text_output = "".join(pieces)
        print(prompt + combined_text_output)
        return combined_text_output

    # Local import keeps this block self-contained; asyncio is stdlib.
    import asyncio

    url = f"{self.model_url}/api/v1/generate"
    params = self._get_parameters(stop)
    request = params.copy()
    request["prompt"] = prompt

    # `requests` is a blocking client. Running it directly inside a coroutine
    # would stall the event loop for the whole HTTP round trip, so the call
    # is handed to the loop's default executor instead.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None, lambda: requests.post(url, json=request)
    )

    if response.status_code == 200:
        result = response.json()["results"][0]["text"]
        print(prompt + result)
    else:
        # Best-effort behaviour kept from the original: report the failure
        # and return an empty completion instead of raising.
        print(f"ERROR: response: {response}")
        result = ""

    return result
def _stream(
self,
prompt: str,
@ -296,3 +347,76 @@ class TextGen(LLM):
if run_manager:
run_manager.on_llm_new_token(token=chunk.text)
async def _astream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
    """Yield generation chunks as they arrive over the streaming websocket.

    A websocket connection is opened to ``{model_url}/api/v1/stream``, the
    request is sent as JSON, and every ``text_stream`` event is turned into
    a ``GenerationChunk``. After each chunk is yielded, the callback
    manager's ``on_llm_new_token`` is awaited with the chunk text. Iteration
    ends on a ``stream_end`` event. Events of any other type are ignored.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional async callback manager notified of each token.
        **kwargs: Extra request parameters; they override the values
            returned by ``self._get_parameters(stop)``.

    Yields:
        ``GenerationChunk`` objects whose ``text`` is the streamed token
        text and whose ``generation_info`` is ``None``.

    Raises:
        ImportError: If the ``websocket-client`` package is not installed.

    Example:
        .. code-block:: python

            from langchain.llms import TextGen
            llm = TextGen(
                model_url="ws://localhost:5005",
                streaming=True,
            )
            async for chunk in llm.astream(
                "Ask 'Hi, how are you?' like a pirate:'", stop=["'", "\n"]
            ):
                print(chunk, end='', flush=True)
    """
    try:
        import websocket
    except ImportError:
        raise ImportError(
            "The `websocket-client` package is required for streaming."
        )
    # Local import keeps this block self-contained; asyncio is stdlib.
    import asyncio

    params = {**self._get_parameters(stop), **kwargs}
    url = f"{self.model_url}/api/v1/stream"
    request = params.copy()
    request["prompt"] = prompt

    # websocket-client is a blocking library; its calls are dispatched to the
    # default executor so the event loop stays responsive while waiting.
    loop = asyncio.get_running_loop()
    websocket_client = websocket.WebSocket()
    await loop.run_in_executor(None, websocket_client.connect, url)

    try:
        await loop.run_in_executor(
            None, websocket_client.send, json.dumps(request)
        )

        while True:
            raw_message = await loop.run_in_executor(
                None, websocket_client.recv
            )
            result = json.loads(raw_message)
            event = result["event"]

            if event == "stream_end":
                return
            if event != "text_stream":
                # Unknown event: there is no new token, so nothing is
                # yielded and the callback is not fired.
                continue

            chunk = GenerationChunk(
                text=result["text"],
                generation_info=None,
            )
            yield chunk

            if run_manager:
                await run_manager.on_llm_new_token(token=chunk.text)
    finally:
        # Runs on normal end, on errors, and when the consumer stops
        # iterating early, so the socket is never leaked.
        websocket_client.close()

Loading…
Cancel
Save