Add LLM for ChatGLM(2)-6B API (#7774)

Description:
Add an LLM wrapper for the ChatGLM-6B & ChatGLM2-6B APIs.

Related Issues: 
- Will the langchain support ChatGLM? #4766
- Add support for selfhost models like ChatGLM or transformer models #1780

Dependencies: 
No extra library installation is required. 
It wraps API calls to a ChatGLM(2)-6B server (started with api.py), so a running
API endpoint is required.
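
For reference, this is roughly the request the wrapper sends to such a server; a minimal sketch with `requests`, where the endpoint URL and payload values are illustrative defaults taken from the wrapper below, not requirements of the server:

```python
import requests

# Sketch of the JSON payload the ChatGLM wrapper builds (see langchain/llms/chatglm.py below).
# The endpoint URL and values are illustrative defaults, not part of this PR's contract.
payload = {
    "prompt": "What should I know before visiting Beijing?",
    "temperature": 0.1,
    "history": [],          # prior [question, answer] pairs, if any
    "max_length": 20000,
    "top_p": 0.7,
}
resp = requests.post("http://127.0.0.1:8000", json=payload)
print(resp.json()["response"])  # the generated text comes back under the "response" key
```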

Tag maintainer:  @mlot 

Any comments on this PR would be appreciated.
---------

Co-authored-by: mlot <limpo2000@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Liu Ming 2023-07-17 22:27:17 +08:00 committed by GitHub
parent 25e3d3f283
commit fa0a9e502a
4 changed files with 265 additions and 0 deletions


@@ -0,0 +1,121 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatGLM\n",
"\n",
"[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) is an open bilingual language model based on General Language Model (GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). \n",
"\n",
"[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the new features like better performance, longer context and more efficient inference.\n",
"\n",
"This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
"ChatGLM-6B and ChatGLM2-6B has the same api specs, so this example should work with both."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import ChatGLM\n",
"from langchain import PromptTemplate, LLMChain\n",
"\n",
"# import os"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"{question}\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# default endpoint_url for a local deployed ChatGLM api server\n",
"endpoint_url = \"http://127.0.0.1:8000\"\n",
"\n",
"# direct access endpoint in a proxied environment\n",
"# os.environ['NO_PROXY'] = '127.0.0.1'\n",
"\n",
"llm = ChatGLM(\n",
" endpoint_url=endpoint_url,\n",
" max_token=80000,\n",
" history=[[\"我将从美国到中国来旅游,出行前希望了解中国的城市\", \"欢迎问我任何问题。\"]],\n",
" top_p=0.9,\n",
" model_kwargs={\"sample_model_args\": False},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ChatGLM payload: {'prompt': '北京和上海两座城市有什么不同?', 'temperature': 0.1, 'history': [['我将从美国到中国来旅游,出行前希望了解中国的城市', '欢迎问我任何问题。']], 'max_length': 80000, 'top_p': 0.9, 'sample_model_args': False}\n"
]
},
{
"data": {
"text/plain": [
"'北京和上海是中国的两个首都,它们在许多方面都有所不同。\\n\\n北京是中国的政治和文化中心拥有悠久的历史和灿烂的文化。它是中国最重要的古都之一也是中国历史上最后一个封建王朝的都城。北京有许多著名的古迹和景点例如紫禁城、天安门广场和长城等。\\n\\n上海是中国最现代化的城市之一也是中国商业和金融中心。上海拥有许多国际知名的企业和金融机构同时也有许多著名的景点和美食。上海的外滩是一个历史悠久的商业区拥有许多欧式建筑和餐馆。\\n\\n除此之外北京和上海在交通和人口方面也有很大差异。北京是中国的首都人口众多交通拥堵问题较为严重。而上海是中国的商业和金融中心人口密度较低交通相对较为便利。\\n\\n总的来说北京和上海是两个拥有独特魅力和特点的城市可以根据自己的兴趣和时间来选择前往其中一座城市旅游。'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"北京和上海两座城市有什么不同?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -14,6 +14,7 @@ from langchain.llms.baseten import Baseten
from langchain.llms.beam import Beam
from langchain.llms.bedrock import Bedrock
from langchain.llms.cerebriumai import CerebriumAI
from langchain.llms.chatglm import ChatGLM
from langchain.llms.clarifai import Clarifai
from langchain.llms.cohere import Cohere
from langchain.llms.ctransformers import CTransformers
@@ -69,6 +70,7 @@ __all__ = [
"Bedrock",
"CTransformers",
"CerebriumAI",
"ChatGLM",
"Clarifai",
"Cohere",
"Databricks",
@@ -125,6 +127,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"baseten": Baseten,
"beam": Beam,
"cerebriumai": CerebriumAI,
"chat_glm": ChatGLM,
"clarifai": Clarifai,
"cohere": Cohere,
"ctransformers": CTransformers,

langchain/llms/chatglm.py (new file, 123 lines)

@@ -0,0 +1,123 @@
from typing import Any, List, Mapping, Optional

import requests

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens


class ChatGLM(LLM):
    """Wrapper around ChatGLM's LLM inference service.

    Example:
        .. code-block:: python

            from langchain.llms import ChatGLM
            endpoint_url = (
                "http://127.0.0.1:8000"
            )
            ChatGLM_llm = ChatGLM(
                endpoint_url=endpoint_url
            )
    """

    endpoint_url: str = "http://127.0.0.1:8000/"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_token: int = 20000
    """Max token allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    history: List[List] = []
    """History of the conversation."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1."""

    @property
    def _llm_type(self) -> str:
        return "chat_glm"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self.endpoint_url},
            **{"model_kwargs": _model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = chatglm_llm("Who are you?")
        """
        _model_kwargs = self.model_kwargs or {}

        # HTTP headers for the JSON request
        headers = {"Content-Type": "application/json"}

        payload = {
            "prompt": prompt,
            "temperature": self.temperature,
            "history": self.history,
            "max_length": self.max_token,
            "top_p": self.top_p,
        }
        payload.update(_model_kwargs)
        payload.update(kwargs)

        # print("ChatGLM payload:", payload)

        # call api
        try:
            response = requests.post(self.endpoint_url, headers=headers, json=payload)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        # print("ChatGLM resp:", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()

            # Check that the response contains generated content
            if isinstance(parsed_response, dict):
                content_keys = "response"
                if content_keys in parsed_response:
                    text = parsed_response[content_keys]
                else:
                    raise ValueError(f"No content in response: {parsed_response}")
            else:
                raise ValueError(f"Unexpected response type: {parsed_response}")
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        self.history = self.history + [[None, parsed_response["response"]]]
        return text
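
The wrapper only requires a JSON endpoint that accepts the payload above and returns the generated text under a "response" key, which is what ChatGLM's bundled api.py provides. For local experiments without a GPU, a minimal stand-in server along these lines should satisfy the wrapper (standard library only; the canned reply and the extra keys mirroring api.py's response shape are assumptions, not part of this PR):

```python
# chatglm_stub.py - a tiny stand-in for ChatGLM's api.py server, for local testing only.
# It accepts the wrapper's JSON payload and echoes a canned reply under the "response" key.
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubHandler(BaseHTTPRequestHandler):
    def do_POST(self) -> None:
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length) or b"{}")
        reply = {
            "response": f"(stub) you said: {payload.get('prompt', '')}",
            "history": payload.get("history", []),
            "status": 200,
        }
        body = json.dumps(reply).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    # Listen on the wrapper's default endpoint.
    HTTPServer(("127.0.0.1", 8000), StubHandler).serve_forever()
```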


@@ -0,0 +1,18 @@
"""Test ChatGLM API wrapper."""
from langchain.llms.chatglm import ChatGLM
from langchain.schema import LLMResult
def test_chatglm_call() -> None:
"""Test valid call to chatglm."""
llm = ChatGLM()
output = llm("北京和上海这两座城市有什么不同?")
assert isinstance(output, str)
def test_chatglm_generate() -> None:
"""Test valid call to chatglm."""
llm = ChatGLM()
output = llm.generate(["who are you"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
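
These integration tests assume a ChatGLM server is already listening on the wrapper's default http://127.0.0.1:8000/. If it lives elsewhere, the wrapper can be constructed with an explicit endpoint; the CHATGLM_ENDPOINT variable name below is only a suggestion, not part of this PR:

```python
import os

from langchain.llms.chatglm import ChatGLM

# Hypothetical override for CI or remote servers; the env var name is illustrative.
llm = ChatGLM(endpoint_url=os.environ.get("CHATGLM_ENDPOINT", "http://127.0.0.1:8000/"))
```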