Add LLM for ChatGLM(2)-6B API (#7774)
Description: Add an LLM wrapper for the ChatGLM-6B and ChatGLM2-6B APIs.
Related issues: "Will the langchain support ChatGLM?" #4766; "Add support for self-hosted models like ChatGLM or transformer models" #1780.
Dependencies: No extra library installation is required. The wrapper makes API calls to a ChatGLM(2)-6B server (started with api.py), so a running API endpoint is required.
Tag maintainer: @mlot
Any comments on this PR would be appreciated.
---------
Co-authored-by: mlot <limpo2000@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 25e3d3f283
commit fa0a9e502a
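Since the PR adds no Python dependency but does assume a running api.py-style endpoint, it can help to smoke-test that endpoint directly before wiring it into LangChain. Below is a minimal sketch (not part of this PR), assuming the default local deployment on port 8000 and the payload/response keys used throughout the diff that follows:

```python
# Quick reachability check for a locally deployed ChatGLM api.py server.
# Assumes http://127.0.0.1:8000 (the default used in this PR) and that the
# server returns a JSON body containing a "response" key, as the wrapper expects.
import requests

payload = {"prompt": "你好", "history": []}
resp = requests.post("http://127.0.0.1:8000", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["response"])
```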
@@ -0,0 +1,121 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatGLM\n",
    "\n",
    "[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) is an open bilingual language model based on the General Language Model (GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy it locally on consumer-grade graphics cards (only 6 GB of GPU memory is required at the INT4 quantization level).\n",
    "\n",
    "[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing new features such as better performance, longer context, and more efficient inference.\n",
    "\n",
    "This example goes over how to use LangChain to interact with a ChatGLM2-6B inference endpoint for text completion.\n",
    "ChatGLM-6B and ChatGLM2-6B have the same API specs, so this example should work with both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import ChatGLM\n",
    "from langchain import PromptTemplate, LLMChain\n",
    "\n",
    "# import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"{question}\"\"\"\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default endpoint_url for a locally deployed ChatGLM api server\n",
    "endpoint_url = \"http://127.0.0.1:8000\"\n",
    "\n",
    "# direct access to the endpoint in a proxied environment\n",
    "# os.environ['NO_PROXY'] = '127.0.0.1'\n",
    "\n",
    "llm = ChatGLM(\n",
    "    endpoint_url=endpoint_url,\n",
    "    max_token=80000,\n",
    "    history=[[\"我将从美国到中国来旅游,出行前希望了解中国的城市\", \"欢迎问我任何问题。\"]],\n",
    "    top_p=0.9,\n",
    "    model_kwargs={\"sample_model_args\": False},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ChatGLM payload: {'prompt': '北京和上海两座城市有什么不同?', 'temperature': 0.1, 'history': [['我将从美国到中国来旅游,出行前希望了解中国的城市', '欢迎问我任何问题。']], 'max_length': 80000, 'top_p': 0.9, 'sample_model_args': False}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'北京和上海是中国的两个首都,它们在许多方面都有所不同。\\n\\n北京是中国的政治和文化中心,拥有悠久的历史和灿烂的文化。它是中国最重要的古都之一,也是中国历史上最后一个封建王朝的都城。北京有许多著名的古迹和景点,例如紫禁城、天安门广场和长城等。\\n\\n上海是中国最现代化的城市之一,也是中国商业和金融中心。上海拥有许多国际知名的企业和金融机构,同时也有许多著名的景点和美食。上海的外滩是一个历史悠久的商业区,拥有许多欧式建筑和餐馆。\\n\\n除此之外,北京和上海在交通和人口方面也有很大差异。北京是中国的首都,人口众多,交通拥堵问题较为严重。而上海是中国的商业和金融中心,人口密度较低,交通相对较为便利。\\n\\n总的来说,北京和上海是两个拥有独特魅力和特点的城市,可以根据自己的兴趣和时间来选择前往其中一座城市旅游。'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question = \"北京和上海两座城市有什么不同?\"\n",
    "\n",
    "llm_chain.run(question)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
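One behavior worth noting in the notebook above: the wrapper keeps conversation state on the LLM instance itself. As the `_call` implementation further below shows, every request sends `self.history` in the payload, and each response is appended back as a `[None, response]` pair, so repeated calls accumulate context. A short sketch of what that implies for the notebook's objects (the follow-up question is hypothetical):

```python
# After llm_chain.run(question) above, the seeded turn plus the appended
# [None, response] pair are both on the instance and will be sent with
# the next request.
print(len(llm.history))  # 2: one seeded pair + one appended by _call

# A hypothetical follow-up call reuses that accumulated history:
llm_chain.run("What is the weather like in Shanghai?")
```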
langchain/llms/__init__.py
@@ -14,6 +14,7 @@ from langchain.llms.baseten import Baseten
 from langchain.llms.beam import Beam
 from langchain.llms.bedrock import Bedrock
 from langchain.llms.cerebriumai import CerebriumAI
+from langchain.llms.chatglm import ChatGLM
 from langchain.llms.clarifai import Clarifai
 from langchain.llms.cohere import Cohere
 from langchain.llms.ctransformers import CTransformers
@@ -69,6 +70,7 @@ __all__ = [
     "Bedrock",
     "CTransformers",
     "CerebriumAI",
+    "ChatGLM",
     "Clarifai",
     "Cohere",
     "Databricks",
@@ -125,6 +127,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "baseten": Baseten,
     "beam": Beam,
     "cerebriumai": CerebriumAI,
+    "chat_glm": ChatGLM,
     "clarifai": Clarifai,
     "cohere": Cohere,
     "ctransformers": CTransformers,
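Registering `"chat_glm"` in `type_to_cls_dict` means the wrapper can also be resolved by its type name, not just imported directly. A small sketch, assuming only the mapping added above:

```python
# Resolve the LLM class by its registered type name and instantiate it.
from langchain.llms import type_to_cls_dict

llm_cls = type_to_cls_dict["chat_glm"]
llm = llm_cls(endpoint_url="http://127.0.0.1:8000")
assert llm._llm_type == "chat_glm"
```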
langchain/llms/chatglm.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from typing import Any, List, Mapping, Optional

import requests

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens


class ChatGLM(LLM):
    """Wrapper around ChatGLM's LLM inference service.

    Example:
        .. code-block:: python

            from langchain.llms import ChatGLM
            endpoint_url = (
                "http://127.0.0.1:8000"
            )
            chatglm_llm = ChatGLM(
                endpoint_url=endpoint_url
            )
    """

    endpoint_url: str = "http://127.0.0.1:8000/"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_token: int = 20000
    """Max tokens allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature, from 0 to 10."""
    history: List[List] = []
    """History of the conversation."""
    top_p: float = 0.7
    """Top-p for nucleus sampling, from 0 to 1."""

    @property
    def _llm_type(self) -> str:
        return "chat_glm"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self.endpoint_url},
            **{"model_kwargs": _model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = chatglm_llm("Who are you?")
        """

        _model_kwargs = self.model_kwargs or {}

        # HTTP headers for the JSON request
        headers = {"Content-Type": "application/json"}

        payload = {
            "prompt": prompt,
            "temperature": self.temperature,
            "history": self.history,
            "max_length": self.max_token,
            "top_p": self.top_p,
        }
        payload.update(_model_kwargs)
        payload.update(kwargs)

        # print("ChatGLM payload:", payload)

        # Call the API endpoint.
        try:
            response = requests.post(self.endpoint_url, headers=headers, json=payload)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        # print("ChatGLM resp:", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()

            # Check that the response content exists.
            if isinstance(parsed_response, dict):
                content_keys = "response"
                if content_keys in parsed_response:
                    text = parsed_response[content_keys]
                else:
                    raise ValueError(f"No content in response: {parsed_response}")
            else:
                raise ValueError(f"Unexpected response type: {parsed_response}")

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        self.history = self.history + [[None, parsed_response["response"]]]
        return text
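For local experimentation without a GPU, the endpoint contract implied by `_call` (a POST accepting `prompt`, `history`, `max_length`, `top_p`, and `temperature`, returning JSON with a `response` key) can be stubbed out. Below is a minimal FastAPI stand-in, an assumption-laden sketch rather than ChatGLM's real api.py:

```python
# stub_server.py -- a minimal stand-in for ChatGLM's api.py. Assumption: the
# wrapper above only requires a JSON body with a "response" key, as read by
# _call. Run with: uvicorn stub_server:app --port 8000
from fastapi import FastAPI, Request

app = FastAPI()


@app.post("/")
async def chat(request: Request) -> dict:
    payload = await request.json()
    prompt = payload.get("prompt", "")
    history = payload.get("history", [])
    # Echo instead of real inference; a real server would call the model here.
    answer = f"stub response to: {prompt}"
    return {"response": answer, "history": history + [[prompt, answer]], "status": 200}
```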
tests/integration_tests/llms/test_chatglm.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Test ChatGLM API wrapper."""
from langchain.llms.chatglm import ChatGLM
from langchain.schema import LLMResult


def test_chatglm_call() -> None:
    """Test a valid call to ChatGLM."""
    llm = ChatGLM()
    output = llm("北京和上海这两座城市有什么不同?")
    assert isinstance(output, str)


def test_chatglm_generate() -> None:
    """Test a valid generate call to ChatGLM."""
    llm = ChatGLM()
    output = llm.generate(["who are you"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
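These integration tests assume a ChatGLM server is already running on the default endpoint. A hedged variant that skips cleanly when none is configured might look like the following, where `CHATGLM_ENDPOINT` is a hypothetical environment variable not used elsewhere in this PR:

```python
import os

import pytest

from langchain.llms.chatglm import ChatGLM


# CHATGLM_ENDPOINT is a hypothetical variable introduced for this sketch.
@pytest.mark.skipif(
    "CHATGLM_ENDPOINT" not in os.environ,
    reason="requires a running ChatGLM api.py server",
)
def test_chatglm_call_guarded() -> None:
    llm = ChatGLM(endpoint_url=os.environ["CHATGLM_ENDPOINT"])
    assert isinstance(llm("who are you"), str)
```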