|
|
|
@ -1,10 +1,13 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any, Dict, Iterator, List, Optional
|
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from langchain.pydantic_v1 import Field
|
|
|
|
|
from langchain.schema.output import GenerationChunk
|
|
|
|
@ -224,6 +227,54 @@ class TextGen(LLM):
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
async def _acall(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Asynchronously call the textgen web API and return the output.

    When ``self.streaming`` is set, the text is collected from
    ``self._astream``; otherwise a single POST is sent to the
    ``/api/v1/generate`` endpoint of ``self.model_url``.

    Args:
        prompt: The prompt to use for generation.
        stop: A list of strings to stop generation when encountered.
        run_manager: Optional async callback manager; it is forwarded to
            ``_astream`` in streaming mode and unused otherwise.
        **kwargs: Extra arguments forwarded to ``_astream`` in streaming
            mode. They are not used by the non-streaming request.

    Returns:
        The generated text, or an empty string when the non-streaming
        request does not return HTTP 200.

    Example:
        .. code-block:: python

            from langchain.llms import TextGen
            llm = TextGen(model_url="http://localhost:5000")
            await llm.agenerate(["Write a story about llamas."])
    """
    # Local import keeps this method self-contained.
    import asyncio

    if self.streaming:
        # Collect every streamed piece and join once instead of growing a
        # string with repeated concatenation.
        pieces = [
            chunk.text
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
        ]
        result = "".join(pieces)
        print(prompt + result)
        return result

    url = f"{self.model_url}/api/v1/generate"
    request = self._get_parameters(stop).copy()
    request["prompt"] = prompt

    # `requests` is a blocking client: run it in the default executor so the
    # event loop keeps serving other tasks while the HTTP call is in flight.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None, lambda: requests.post(url, json=request)
    )

    if response.status_code == 200:
        result = response.json()["results"][0]["text"]
        print(prompt + result)
    else:
        # NOTE(review): failures are only printed and reported to the caller
        # as an empty string; kept as-is so the return contract is unchanged.
        print(f"ERROR: response: {response}")
        result = ""

    return result
|
|
|
|
|
|
|
|
|
|
def _stream(
|
|
|
|
|
self,
|
|
|
|
|
prompt: str,
|
|
|
|
@ -296,3 +347,76 @@ class TextGen(LLM):
|
|
|
|
|
|
|
|
|
|
if run_manager:
|
|
|
|
|
run_manager.on_llm_new_token(token=chunk.text)
|
|
|
|
|
|
|
|
|
|
async def _astream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
    """Yield result objects as they are generated in real time.

    A websocket is opened to the ``/api/v1/stream`` endpoint of
    ``self.model_url``, the request is sent as JSON, and one
    ``GenerationChunk`` is yielded per ``text_stream`` event until a
    ``stream_end`` event arrives. After each yielded chunk the callback
    manager's ``on_llm_new_token`` is awaited with the chunk text.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional async callback manager notified of each token.
        **kwargs: Extra request parameters; they override the values from
            ``self._get_parameters(stop)``.

    Yields:
        ``GenerationChunk`` objects holding the text of each streamed piece.
        See text-generation-webui docs and below for more.

    Raises:
        ImportError: If the ``websocket-client`` package is not installed.

    Example:
        .. code-block:: python

            from langchain.llms import TextGen
            llm = TextGen(
                model_url = "ws://localhost:5005"
                streaming=True
            )
            async for chunk in llm.astream(
                "Ask 'Hi, how are you?' like a pirate:'",
                stop=["'","\\n"],
            ):
                print(chunk, end='', flush=True)

    """
    # Local import keeps this method self-contained.
    import asyncio

    try:
        import websocket
    except ImportError:
        raise ImportError(
            "The `websocket-client` package is required for streaming."
        )

    url = f"{self.model_url}/api/v1/stream"
    # Later keys win: kwargs override the defaults, and the prompt is set last.
    request = {**self._get_parameters(stop), **kwargs}
    request["prompt"] = prompt

    # The websocket client is blocking, so every network call is pushed to
    # the default executor to avoid stalling the event loop.
    loop = asyncio.get_running_loop()
    websocket_client = websocket.WebSocket()
    try:
        await loop.run_in_executor(None, websocket_client.connect, url)
        await loop.run_in_executor(
            None, websocket_client.send, json.dumps(request)
        )

        while True:
            raw_message = await loop.run_in_executor(
                None, websocket_client.recv
            )
            result = json.loads(raw_message)
            event = result["event"]

            if event == "text_stream":
                chunk = GenerationChunk(
                    text=result["text"],
                    generation_info=None,
                )
                yield chunk
                # Notify only for a chunk that was actually produced; any
                # other event must not re-emit a stale (or unbound) chunk.
                if run_manager:
                    await run_manager.on_llm_new_token(token=chunk.text)
            elif event == "stream_end":
                return
    finally:
        # Runs on normal end, on errors, and when the consumer stops
        # iterating early, so the socket is never leaked.
        websocket_client.close()
|
|
|
|
|