From 499e76b1996787f714a020917a58a4be0d2896ac Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Thu, 2 Mar 2023 17:04:18 +0000
Subject: [PATCH] Allow the regular openai class to be used for ChatGPT models
 (#1393)

Co-authored-by: Harrison Chase
---
 .gitignore                                  | 1 +
 langchain/llms/openai.py                    | 6 ++++++
 pyproject.toml                              | 2 +-
 tests/integration_tests/llms/test_openai.py | 7 +++++++
 4 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 2a7a13ed..0b7bffc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,6 +106,7 @@ celerybeat.pid
 
 # Environments
 .env
+.envrc
 .venv
 .venvs
 env/
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index f0646c46..5d52fd3a 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -161,6 +161,12 @@ class BaseOpenAI(BaseLLM, BaseModel):
     streaming: bool = False
     """Whether to stream the results or not."""
 
+    def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]:  # type: ignore
+        """Initialize the OpenAI object."""
+        if data.get("model_name", "").startswith("gpt-3.5-turbo"):
+            return OpenAIChat(**data)
+        return super().__new__(cls)
+
     class Config:
         """Configuration for this pydantic object."""
 
diff --git a/pyproject.toml b/pyproject.toml
index e32e2830..e2e32bdc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain"
-version = "0.0.99"
+version = "0.0.100"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index d5f7e61b..818068b6 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -144,6 +144,13 @@ async def test_openai_async_streaming_callback() -> None:
     assert isinstance(result, LLMResult)
 
 
+def test_openai_chat_wrong_class() -> None:
+    """Test OpenAIChat with wrong class still works."""
+    llm = OpenAI(model_name="gpt-3.5-turbo")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_openai_chat() -> None:
     """Test OpenAIChat."""
     llm = OpenAIChat(max_tokens=10)
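
Note (not part of the diff): a minimal usage sketch of the behaviour the patch introduces. Constructing the generic OpenAI wrapper with a "gpt-3.5-turbo" model name is routed through BaseOpenAI.__new__ and yields an OpenAIChat instance, while other model names keep the completion-style path. The snippet assumes langchain 0.0.100 with an OPENAI_API_KEY set in the environment; the "text-davinci-003" model name is only illustrative.

# Sketch only: demonstrates the routing added in BaseOpenAI.__new__ above.
from langchain.llms.openai import OpenAI, OpenAIChat

chat_llm = OpenAI(model_name="gpt-3.5-turbo")  # __new__ returns an OpenAIChat
assert isinstance(chat_llm, OpenAIChat)

completion_llm = OpenAI(model_name="text-davinci-003")  # non-chat names keep the old path
assert isinstance(completion_llm, OpenAI)

print(chat_llm("Say foo:"))  # both wrappers expose the same __call__ interface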