diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index cce51bc4..d66b6405 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -163,7 +163,8 @@ class BaseOpenAI(BaseLLM, BaseModel): def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore """Initialize the OpenAI object.""" - if data.get("model_name", "").startswith("gpt-3.5-turbo"): + model_name = data.get("model_name", "") + if model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4"): return OpenAIChat(**data) return super().__new__(cls)