Mirror of https://github.com/hwchase17/langchain, synced 2024-11-08 07:10:35 +00:00
Adding LLM wrapper for Kobold AI (#7560)
- Description: add a wrapper that lets you use the KoboldAI API in LangChain
- Issue: n/a
- Dependencies: none extra, just what already exists in langchain
- Tag maintainer: @baskaryan
- Twitter handle: @zanzibased

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 603a0bea29
commit 50316f6477
@@ -0,0 +1,88 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FPF4vhdZyJ7S"
   },
   "source": [
    "# KoboldAI API\n",
    "\n",
    "[KoboldAI](https://github.com/KoboldAI/KoboldAI-Client) is \"a browser-based front-end for AI-assisted writing with multiple local & remote AI models...\". It has a public and local API that can be used in LangChain.\n",
    "\n",
    "This example goes over how to use LangChain with that API.\n",
    "\n",
    "Documentation can be found in the browser by adding /api to the end of your endpoint (e.g. http://127.0.0.1:5000/api).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "lyzOsRRTf_Vr"
   },
   "outputs": [],
   "source": [
    "from langchain.llms import KoboldApiLLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1a_H7mvfy51O"
   },
   "source": [
    "Replace the endpoint seen below with the one shown in the output after starting the webui with --api or --public-api.\n",
    "\n",
    "Optionally, you can pass in parameters like temperature or max_length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "g3vGebq8f_Vr"
   },
   "outputs": [],
   "source": [
    "llm = KoboldApiLLM(endpoint=\"http://192.168.1.144:5000\", max_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sPxNGGiDf_Vr",
    "outputId": "024a1d62-3cd7-49a8-c6a8-5278224d02ef"
   },
   "outputs": [],
   "source": [
    "response = llm(\"### Instruction:\\nWhat is the first book of the bible?\\n### Response:\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
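For reference, the wrapper's _call also accepts stop sequences, which the notebook does not demonstrate. A minimal sketch; the endpoint and prompt are the notebook's, the stop list is illustrative, and a running KoboldAI server is assumed:

from langchain.llms import KoboldApiLLM

# Stop sequences are forwarded to the API as "stop_sequence" and are also
# trimmed from the end of the returned text by the wrapper.
llm = KoboldApiLLM(endpoint="http://192.168.1.144:5000", max_length=80)
response = llm(
    "### Instruction:\nWhat is the first book of the bible?\n### Response:",
    stop=["### Instruction:"],
)
print(response)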
@@ -29,6 +29,7 @@ from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
from langchain.llms.human import HumanInputLLM
from langchain.llms.koboldai import KoboldApiLLM
from langchain.llms.llamacpp import LlamaCpp
from langchain.llms.manifest import ManifestWrapper
from langchain.llms.modal import Modal
@@ -81,6 +82,7 @@ __all__ = [
    "HuggingFacePipeline",
    "HuggingFaceTextGenInference",
    "HumanInputLLM",
    "KoboldApiLLM",
    "LlamaCpp",
    "TextGen",
    "ManifestWrapper",
@@ -136,6 +138,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "huggingface_pipeline": HuggingFacePipeline,
    "huggingface_textgen_inference": HuggingFaceTextGenInference,
    "human-input": HumanInputLLM,
    "koboldai": KoboldApiLLM,
    "llamacpp": LlamaCpp,
    "textgen": TextGen,
    "modal": Modal,
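Registering "koboldai" in type_to_cls_dict also makes the wrapper loadable from a serialized config. A rough sketch, assuming the standard load_llm_from_config path in langchain.llms.loading; the endpoint value is illustrative:

from langchain.llms.loading import load_llm_from_config

# "_type" selects the class via type_to_cls_dict; the remaining keys become
# the model's fields (endpoint, max_length, temperature, ...).
config = {"_type": "koboldai", "endpoint": "http://localhost:5000", "max_length": 80}
llm = load_llm_from_config(config)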
langchain/llms/koboldai.py (200 lines, new file)
@@ -0,0 +1,200 @@
"""Wrapper around KoboldAI API."""
import logging
from typing import Any, Dict, List, Optional

import requests

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


def clean_url(url: str) -> str:
    """Remove trailing slash and /api from url if present."""
    if url.endswith("/api"):
        return url[:-4]
    elif url.endswith("/"):
        return url[:-1]
    else:
        return url


class KoboldApiLLM(LLM):
    """
    A class that acts as a wrapper for the Kobold API language model.

    It includes several fields that can be used to control the text generation process.

    To use this class, instantiate it with the required parameters and call it with a
    prompt to generate text. For example:

        kobold = KoboldApiLLM(endpoint="http://localhost:5000")
        result = kobold("Write a story about a dragon.")

    This will send a POST request to the Kobold API with the provided prompt and
    generate text.
    """

    endpoint: str
    """The API endpoint to use for generating text."""

    use_story: Optional[bool] = False
    """Whether or not to use the story from the KoboldAI GUI when generating text."""

    use_authors_note: Optional[bool] = False
    """Whether to use the author's note from the KoboldAI GUI when generating text.

    This has no effect unless use_story is also enabled.
    """

    use_world_info: Optional[bool] = False
    """Whether to use the world info from the KoboldAI GUI when generating text."""

    use_memory: Optional[bool] = False
    """Whether to use the memory from the KoboldAI GUI when generating text."""

    max_context_length: Optional[int] = 1600
    """Maximum number of tokens to send to the model.

    minimum: 1
    """

    max_length: Optional[int] = 80
    """Number of tokens to generate.

    maximum: 512
    minimum: 1
    """

    rep_pen: Optional[float] = 1.12
    """Base repetition penalty value.

    minimum: 1
    """

    rep_pen_range: Optional[int] = 1024
    """Repetition penalty range.

    minimum: 0
    """

    rep_pen_slope: Optional[float] = 0.9
    """Repetition penalty slope.

    minimum: 0
    """

    temperature: Optional[float] = 0.6
    """Temperature value.

    exclusiveMinimum: 0
    """

    tfs: Optional[float] = 0.9
    """Tail free sampling value.

    maximum: 1
    minimum: 0
    """

    top_a: Optional[float] = 0.9
    """Top-a sampling value.

    minimum: 0
    """

    top_p: Optional[float] = 0.95
    """Top-p sampling value.

    maximum: 1
    minimum: 0
    """

    top_k: Optional[int] = 0
    """Top-k sampling value.

    minimum: 0
    """

    typical: Optional[float] = 0.5
    """Typical sampling value.

    maximum: 1
    minimum: 0
    """

    @property
    def _llm_type(self) -> str:
        return "koboldai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain.llms import KoboldApiLLM

                llm = KoboldApiLLM(endpoint="http://localhost:5000")
                llm("Write a story about dragons.")
        """
        data: Dict[str, Any] = {
            "prompt": prompt,
            "use_story": self.use_story,
            "use_authors_note": self.use_authors_note,
            "use_world_info": self.use_world_info,
            "use_memory": self.use_memory,
            "max_context_length": self.max_context_length,
            "max_length": self.max_length,
            "rep_pen": self.rep_pen,
            "rep_pen_range": self.rep_pen_range,
            "rep_pen_slope": self.rep_pen_slope,
            "temperature": self.temperature,
            "tfs": self.tfs,
            "top_a": self.top_a,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "typical": self.typical,
        }

        if stop is not None:
            data["stop_sequence"] = stop

        response = requests.post(
            f"{clean_url(self.endpoint)}/api/v1/generate", json=data
        )

        response.raise_for_status()
        json_response = response.json()

        if (
            "results" in json_response
            and len(json_response["results"]) > 0
            and "text" in json_response["results"][0]
        ):
            text = json_response["results"][0]["text"].strip()

            if stop is not None:
                for sequence in stop:
                    if text.endswith(sequence):
                        text = text[: -len(sequence)].rstrip()

            return text
        else:
            raise ValueError(
                f"Unexpected response format from Kobold API: {json_response}"
            )
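For clarity, this is roughly the raw HTTP exchange that _call performs: clean_url() strips a trailing "/" or "/api", the payload is posted to /api/v1/generate, and the text is read from results[0]["text"]. A hand-rolled equivalent sketch; the endpoint and prompt are illustrative and assume a local KoboldAI server:

import requests

# clean_url would leave this endpoint unchanged (no trailing "/" or "/api").
endpoint = "http://localhost:5000"
payload = {"prompt": "Write a story about a dragon.", "max_length": 80, "temperature": 0.6}
response = requests.post(f"{endpoint}/api/v1/generate", json=payload)
response.raise_for_status()
print(response.json()["results"][0]["text"].strip())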