langchain[patch]: infer mistral provider in init_chat_model (#26557)

This commit is contained in:
Bagatur 2024-09-16 17:35:54 -07:00 committed by GitHub
parent 31f61d4d7d
commit d8952b8e8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -109,6 +109,7 @@ def init_chat_model(
- google_vertexai (langchain-google-vertexai)
- google_genai (langchain-google-genai)
- bedrock (langchain-aws)
- bedrock_converse (langchain-aws)
- cohere (langchain-cohere)
- fireworks (langchain-fireworks)
- together (langchain-together)
@ -120,12 +121,13 @@ def init_chat_model(
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- gpt-3... or gpt-4... -> openai
- gpt-3..., gpt-4..., or o1... -> openai
- claude... -> anthropic
- amazon.... -> bedrock
- gemini... -> google_vertexai
- command... -> cohere
- accounts/fireworks... -> fireworks
- mistral... -> mistralai
configurable_fields: Which model parameters are
configurable:
@ -276,8 +278,13 @@ def init_chat_model(
.. versionchanged:: 0.2.12
Support for Ollama via langchain-ollama package added. Previously
langchain-community version of Ollama (now deprecated) was installed by default.
Support for ChatOllama via langchain-ollama package added
(langchain_ollama.ChatOllama). Previously,
the now-deprecated langchain-community version of Ollama was imported
(langchain_community.chat_models.ChatOllama).
Support for langchain_aws.ChatBedrockConverse added
(model_provider="bedrock_converse").
""" # noqa: E501
if not model and not configurable_fields:
@ -424,7 +431,7 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"):
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1")):
return "openai"
elif model_name.startswith("claude"):
return "anthropic"
@ -436,6 +443,8 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return "google_vertexai"
elif model_name.startswith("amazon."):
return "bedrock"
elif model_name.startswith("mistral"):
return "mistralai"
else:
return None