openai adapters (#8988)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Harrison Chase 2023-08-10 16:08:50 -07:00 committed by GitHub
parent 45f0f9460a
commit bb6fbf4c71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 649 additions and 72 deletions

View File

@ -0,0 +1,323 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "700a516b",
"metadata": {},
"source": [
"# OpenAI Adapter\n",
"\n",
"A lot of people get started with OpenAI but want to explore other models. LangChain's integrations with many model providers make this easy to do so. While LangChain has it's own message and model APIs, we've also made it as easy as possible to explore other models by exposing an adapter to adapt LangChain models to the OpenAI api.\n",
"\n",
"At the moment this only deals with output and does not return other information (token counts, stop reasons, etc)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6017f26a",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"from langchain.adapters import openai as lc_openai"
]
},
{
"cell_type": "markdown",
"id": "b522ceda",
"metadata": {},
"source": [
"## ChatCompletion.create"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1d22eb61",
"metadata": {},
"outputs": [],
"source": [
"messages = [{\"role\": \"user\", \"content\": \"hi\"}]"
]
},
{
"cell_type": "markdown",
"id": "d550d3ad",
"metadata": {},
"source": [
"Original OpenAI call"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e1d27dfa",
"metadata": {},
"outputs": [],
"source": [
"result = openai.ChatCompletion.create(\n",
" messages=messages, \n",
" model=\"gpt-3.5-turbo\", \n",
" temperature=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "012d81ae",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant', 'content': 'Hello! How can I assist you today?'}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[\"choices\"][0]['message'].to_dict_recursive()"
]
},
{
"cell_type": "markdown",
"id": "db5b5500",
"metadata": {},
"source": [
"LangChain OpenAI wrapper call"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "87c2d515",
"metadata": {},
"outputs": [],
"source": [
"lc_result = lc_openai.ChatCompletion.create(\n",
" messages=messages, \n",
" model=\"gpt-3.5-turbo\", \n",
" temperature=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c67a5ac8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant', 'content': 'Hello! How can I assist you today?'}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lc_result[\"choices\"][0]['message']"
]
},
{
"cell_type": "markdown",
"id": "034ba845",
"metadata": {},
"source": [
"Swapping out model providers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7a2c011c",
"metadata": {},
"outputs": [],
"source": [
"lc_result = lc_openai.ChatCompletion.create(\n",
" messages=messages, \n",
" model=\"claude-2\", \n",
" temperature=0, \n",
" provider=\"ChatAnthropic\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "f7c94827",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant', 'content': ' Hello!'}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lc_result[\"choices\"][0]['message']"
]
},
{
"cell_type": "markdown",
"id": "cb3f181d",
"metadata": {},
"source": [
"## ChatCompletion.stream"
]
},
{
"cell_type": "markdown",
"id": "f7b8cd18",
"metadata": {},
"source": [
"Original OpenAI call"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "fd8cb1ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': ''}\n",
"{'content': 'Hello'}\n",
"{'content': '!'}\n",
"{'content': ' How'}\n",
"{'content': ' can'}\n",
"{'content': ' I'}\n",
"{'content': ' assist'}\n",
"{'content': ' you'}\n",
"{'content': ' today'}\n",
"{'content': '?'}\n",
"{}\n"
]
}
],
"source": [
"for c in openai.ChatCompletion.create(\n",
" messages = messages,\n",
" model=\"gpt-3.5-turbo\", \n",
" temperature=0,\n",
" stream=True\n",
"):\n",
" print(c[\"choices\"][0]['delta'].to_dict_recursive())"
]
},
{
"cell_type": "markdown",
"id": "0b2a076b",
"metadata": {},
"source": [
"LangChain OpenAI wrapper call"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "9521218c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': ''}\n",
"{'content': 'Hello'}\n",
"{'content': '!'}\n",
"{'content': ' How'}\n",
"{'content': ' can'}\n",
"{'content': ' I'}\n",
"{'content': ' assist'}\n",
"{'content': ' you'}\n",
"{'content': ' today'}\n",
"{'content': '?'}\n",
"{}\n"
]
}
],
"source": [
"for c in lc_openai.ChatCompletion.create(\n",
" messages = messages,\n",
" model=\"gpt-3.5-turbo\", \n",
" temperature=0,\n",
" stream=True\n",
"):\n",
" print(c[\"choices\"][0]['delta'])"
]
},
{
"cell_type": "markdown",
"id": "0fc39750",
"metadata": {},
"source": [
"Swapping out model providers"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "68f0214e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'role': 'assistant', 'content': ' Hello'}\n",
"{'content': '!'}\n",
"{}\n"
]
}
],
"source": [
"for c in lc_openai.ChatCompletion.create(\n",
" messages = messages,\n",
" model=\"claude-2\", \n",
" temperature=0,\n",
" stream=True,\n",
" provider=\"ChatAnthropic\",\n",
"):\n",
" print(c[\"choices\"][0]['delta'])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,208 @@
from __future__ import annotations
import importlib
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Mapping,
Sequence,
Union,
overload,
)
from typing_extensions import Literal
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
async def aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate."""
i = start
async for x in iterable:
yield i, x
i += 1
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
messages: List of dictionaries representing OpenAI messages
Returns:
List of LangChain BaseMessage objects.
"""
return [convert_dict_to_message(m) for m in messages]
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
_dict: Dict[str, Any] = {}
if isinstance(chunk, AIMessageChunk):
if i == 0:
# Only shows up in the first chunk
_dict["role"] = "assistant"
if "function_call" in chunk.additional_kwargs:
_dict["function_call"] = chunk.additional_kwargs["function_call"]
# If the first chunk is a function call, the content is not empty string,
# not missing, but None.
if i == 0:
_dict["content"] = None
else:
_dict["content"] = chunk.content
else:
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
# This only happens at the end of streams, and OpenAI returns as empty dict
if _dict == {"content": ""}:
_dict = {}
return {"choices": [{"delta": _dict}]}
class ChatCompletion:
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable:
...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict:
...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator:
...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
async for i, c in aenumerate(model_config.astream(converted_messages))
)

View File

@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Optional, Set
import requests
from pydantic import Field, root_validator
from langchain.adapters.openai import convert_message_to_dict
from langchain.chat_models.openai import (
ChatOpenAI,
_convert_message_to_dict,
_import_tiktoken,
)
from langchain.schema.messages import BaseMessage
@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [_convert_message_to_dict(m) for m in messages]
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():

View File

@ -19,6 +19,7 @@ from typing import (
from pydantic import Field, root_validator
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@ -27,17 +28,12 @@ from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
@ -121,63 +117,6 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_openai_messages(messages: List[dict]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
messages: List of dictionaries representing OpenAI messages
Returns:
List of LangChain BaseMessage objects.
"""
return [_convert_dict_to_message(m) for m in messages]
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class ChatOpenAI(BaseChatModel):
"""Wrapper around OpenAI Chat large language models.
@ -411,13 +350,13 @@ class ChatOpenAI(BaseChatModel):
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
message = convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
@ -568,7 +507,7 @@ class ChatOpenAI(BaseChatModel):
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [_convert_message_to_dict(m) for m in messages]
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():

View File

@ -0,0 +1,107 @@
from typing import Any
import openai
import pytest
from langchain.adapters import openai as lcopenai
def _test_no_stream(**kwargs: Any) -> None:
result = openai.ChatCompletion.create(**kwargs)
lc_result = lcopenai.ChatCompletion.create(**kwargs)
if isinstance(lc_result, dict):
if isinstance(result, dict):
result_dict = result["choices"][0]["message"].to_dict_recursive()
lc_result_dict = lc_result["choices"][0]["message"]
assert result_dict == lc_result_dict
return
def _test_stream(**kwargs: Any) -> None:
result = []
for c in openai.ChatCompletion.create(**kwargs):
result.append(c["choices"][0]["delta"].to_dict_recursive())
lc_result = []
for c in lcopenai.ChatCompletion.create(**kwargs):
lc_result.append(c["choices"][0]["delta"])
assert result == lc_result
async def _test_async(**kwargs: Any) -> None:
result = await openai.ChatCompletion.acreate(**kwargs)
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs)
if isinstance(lc_result, dict):
if isinstance(result, dict):
result_dict = result["choices"][0]["message"].to_dict_recursive()
lc_result_dict = lc_result["choices"][0]["message"]
assert result_dict == lc_result_dict
return
async def _test_astream(**kwargs: Any) -> None:
result = []
async for c in await openai.ChatCompletion.acreate(**kwargs):
result.append(c["choices"][0]["delta"].to_dict_recursive())
lc_result = []
async for c in await lcopenai.ChatCompletion.acreate(**kwargs):
lc_result.append(c["choices"][0]["delta"])
assert result == lc_result
FUNCTIONS = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
async def _test_module(**kwargs: Any) -> None:
_test_no_stream(**kwargs)
await _test_async(**kwargs)
_test_stream(stream=True, **kwargs)
await _test_astream(stream=True, **kwargs)
@pytest.mark.asyncio
async def test_normal_call() -> None:
await _test_module(
messages=[{"role": "user", "content": "hi"}],
model="gpt-3.5-turbo",
temperature=0,
)
@pytest.mark.asyncio
async def test_function_calling() -> None:
await _test_module(
messages=[{"role": "user", "content": "whats the weather in boston"}],
model="gpt-3.5-turbo",
functions=FUNCTIONS,
temperature=0,
)
@pytest.mark.asyncio
async def test_answer_with_function_calling() -> None:
await _test_module(
messages=[
{"role": "user", "content": "say hi, then whats the weather in boston"}
],
model="gpt-3.5-turbo",
functions=FUNCTIONS,
temperature=0,
)

View File

@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
import pytest
from langchain.adapters.openai import convert_dict_to_message
from langchain.chat_models.openai import (
ChatOpenAI,
_convert_dict_to_message,
)
from langchain.schema.messages import (
AIMessage,
@ -20,7 +20,7 @@ from langchain.schema.messages import (
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
result = convert_dict_to_message(
{
"role": "function",
"name": name,
@ -34,21 +34,21 @@ def test_function_message_dict_to_function_message() -> None:
def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message)
result = convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
result = convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
result = convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output