From d8952b8e8c1ee824f2bce27988d401bfcfd96779 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 16 Sep 2024 17:35:54 -0700 Subject: [PATCH] langchain[patch]: infer mistral provider in init_chat_model (#26557) --- libs/langchain/langchain/chat_models/base.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index f2988715a7..6c66a6f347 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -109,6 +109,7 @@ def init_chat_model( - google_vertexai (langchain-google-vertexai) - google_genai (langchain-google-genai) - bedrock (langchain-aws) + - bedrock_converse (langchain-aws) - cohere (langchain-cohere) - fireworks (langchain-fireworks) - together (langchain-together) @@ -120,12 +121,13 @@ def init_chat_model( Will attempt to infer model_provider from model if not specified. The following providers will be inferred based on these model prefixes: - - gpt-3... or gpt-4... -> openai + - gpt-3..., gpt-4..., or o1... -> openai - claude... -> anthropic - amazon.... -> bedrock - gemini... -> google_vertexai - command... -> cohere - accounts/fireworks... -> fireworks + - mistral... -> mistralai configurable_fields: Which model parameters are configurable: @@ -276,8 +278,13 @@ def init_chat_model( .. versionchanged:: 0.2.12 - Support for Ollama via langchain-ollama package added. Previously - langchain-community version of Ollama (now deprecated) was installed by default. + Support for ChatOllama via langchain-ollama package added + (langchain_ollama.ChatOllama). Previously, + the now-deprecated langchain-community version of Ollama was imported + (langchain_community.chat_models.ChatOllama). + + Support for langchain_aws.ChatBedrockConverse added + (model_provider="bedrock_converse"). """ # noqa: E501 if not model and not configurable_fields: @@ -424,7 +431,7 @@ _SUPPORTED_PROVIDERS = { def _attempt_infer_model_provider(model_name: str) -> Optional[str]: - if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"): + if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1")): return "openai" elif model_name.startswith("claude"): return "anthropic" @@ -436,6 +443,8 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]: return "google_vertexai" elif model_name.startswith("amazon."): return "bedrock" + elif model_name.startswith("mistral"): + return "mistralai" else: return None