You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/community/langchain_community/llms/koboldai.py

198 lines
5.0 KiB
Python

import logging
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def clean_url(url: str) -> str:
    """Normalize a Kobold endpoint to its base URL.

    Removes any trailing slashes and a trailing ``/api`` path segment, in
    either combination, so that a caller can safely append an ``/api/...``
    path without producing a doubled ``/api/api`` or ``//`` in the result.

    Args:
        url: The endpoint as supplied by the user, e.g.
            ``"http://localhost:5000"``, ``"http://localhost:5000/"``,
            ``"http://localhost:5000/api"`` or ``"http://localhost:5000/api/"``.

    Returns:
        The URL without a trailing ``/api`` segment or trailing slashes.
    """
    # Strip slashes first so that ".../api/" is recognised as ending in "/api".
    base = url.rstrip("/")
    if base.endswith("/api"):
        # Drop the suffix, then any slash that preceded it (e.g. "host//api").
        base = base[: -len("/api")].rstrip("/")
    return base
class KoboldApiLLM(LLM):
    """Kobold API language model.

    It includes several fields that can be used to control the text generation
    process.

    To use this class, instantiate it with the required parameters and invoke it
    with a prompt to generate text. For example:

        .. code-block:: python

            kobold = KoboldApiLLM(endpoint="http://localhost:5000")
            result = kobold.invoke("Write a story about a dragon.")

    This will send a POST request to the Kobold API with the provided prompt and
    generate text.
    """

    endpoint: str
    """The API endpoint to use for generating text."""

    use_story: Optional[bool] = False
    """Whether or not to use the story from the KoboldAI GUI when generating text."""

    use_authors_note: Optional[bool] = False
    """Whether to use the author's note from the KoboldAI GUI when generating text.

    This has no effect unless use_story is also enabled.
    """

    use_world_info: Optional[bool] = False
    """Whether to use the world info from the KoboldAI GUI when generating text."""

    use_memory: Optional[bool] = False
    """Whether to use the memory from the KoboldAI GUI when generating text."""

    max_context_length: Optional[int] = 1600
    """Maximum number of tokens to send to the model.

    minimum: 1
    """

    max_length: Optional[int] = 80
    """Number of tokens to generate.

    maximum: 512
    minimum: 1
    """

    rep_pen: Optional[float] = 1.12
    """Base repetition penalty value.

    minimum: 1
    """

    rep_pen_range: Optional[int] = 1024
    """Repetition penalty range.

    minimum: 0
    """

    rep_pen_slope: Optional[float] = 0.9
    """Repetition penalty slope.

    minimum: 0
    """

    temperature: Optional[float] = 0.6
    """Temperature value.

    exclusiveMinimum: 0
    """

    tfs: Optional[float] = 0.9
    """Tail free sampling value.

    maximum: 1
    minimum: 0
    """

    top_a: Optional[float] = 0.9
    """Top-a sampling value.

    minimum: 0
    """

    top_p: Optional[float] = 0.95
    """Top-p sampling value.

    maximum: 1
    minimum: 0
    """

    top_k: Optional[int] = 0
    """Top-k sampling value.

    minimum: 0
    """

    typical: Optional[float] = 0.5
    """Typical sampling value.

    maximum: 1
    minimum: 0
    """

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "koboldai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered. They
                are sent to the API as ``stop_sequence`` and, if the returned
                text still ends with one of them, it is trimmed off. Empty
                strings in the list are ignored when trimming.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            requests.HTTPError: If the API responds with an error status
                (raised by ``response.raise_for_status()``).
            ValueError: If the JSON response does not contain
                ``results[0]["text"]``.

        Example:
            .. code-block:: python

                from langchain_community.llms import KoboldApiLLM

                llm = KoboldApiLLM(endpoint="http://localhost:5000")
                llm.invoke("Write a story about dragons.")
        """
        data: Dict[str, Any] = {
            "prompt": prompt,
            "use_story": self.use_story,
            "use_authors_note": self.use_authors_note,
            "use_world_info": self.use_world_info,
            "use_memory": self.use_memory,
            "max_context_length": self.max_context_length,
            "max_length": self.max_length,
            "rep_pen": self.rep_pen,
            "rep_pen_range": self.rep_pen_range,
            "rep_pen_slope": self.rep_pen_slope,
            "temperature": self.temperature,
            "tfs": self.tfs,
            "top_a": self.top_a,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "typical": self.typical,
        }

        if stop is not None:
            data["stop_sequence"] = stop

        response = requests.post(
            f"{clean_url(self.endpoint)}/api/v1/generate", json=data
        )
        response.raise_for_status()
        json_response = response.json()

        has_text = (
            "results" in json_response
            and len(json_response["results"]) > 0
            and "text" in json_response["results"][0]
        )
        if not has_text:
            raise ValueError(
                f"Unexpected response format from Kobold API: {json_response}"
            )

        text = json_response["results"][0]["text"].strip()

        if stop is not None:
            for sequence in stop:
                # Skip empty sequences: every string "ends with" "", and
                # text[:-0] is text[:0], which would wipe the whole output.
                if sequence and text.endswith(sequence):
                    text = text[: -len(sequence)].rstrip()

        return text