langchain[minor]: add universal init_model (#22039)

Decisions to discuss:
- only chat models
- model_provider isn't based on any existing values like llm-type,
  package names, or class names
- implemented as a function, not as a wrapper ChatModel (see the sketch below)
- function name (init_model)
- in langchain as opposed to community or core
- marked beta
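
For reference, the resulting API is a single function call. A minimal sketch based on the diff below (assumes `langchain-openai` is installed):

from langchain.chat_models import init_chat_model

# Returns a langchain_openai.ChatOpenAI instance.
llm = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
llm.invoke("what's your name")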
Committed by Bagatur via GitHub
parent 67012c2558
commit 1a911018bc

@@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cfdf4f09-8125-4ed1-8063-6feed57da8a3",
"metadata": {},
"source": [
"# How to let your end users choose their model\n",
"\n",
"Many LLM applications let end users specify what model provider and model they want the application to be powered by. This requires writing some logic to initialize different ChatModels based on some user configuration. The `init_chat_model()` helper method makes it easy to initialize a number of different model integrations without having to worry about import paths and class names.\n",
"\n",
":::tip Supported models\n",
"\n",
"See the [init_chat_model()](https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.init_chat_model.html) API reference for a full list of supported integrations.\n",
"\n",
"Make sure you have the integration packages installed for any model providers you want to support. E.g. you should have `langchain-openai` installed to init an OpenAI model.\n",
"\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "165b0de6-9ae3-4e3d-aa98-4fc8a97c4a06",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai"
]
},
{
"cell_type": "markdown",
"id": "ea2c9f57-a796-45f8-b6f4-3efd3f361a9b",
"metadata": {},
"source": [
"## Basic usage"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "79e14913-803c-4382-9009-5c6af3d75d35",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-4o: I'm an AI created by OpenAI, and I don't have a personal name. You can call me Assistant! How can I help you today?\n",
"\n",
"Claude Opus: My name is Claude. It's nice to meet you!\n",
"\n",
"Gemini 1.5: I am a large language model, trained by Google. I do not have a name. \n",
"\n",
"\n"
]
}
],
"source": [
"from langchain.chat_models import init_chat_model\n",
"\n",
"# Returns a langchain_openai.ChatOpenAI instance.\n",
"gpt_4o = init_chat_model(\"gpt-4o\", model_provider=\"openai\", temperature=0)\n",
"# Returns a langchain_anthropic.ChatAnthropic instance.\n",
"claude_opus = init_chat_model(\n",
" \"claude-3-opus-20240229\", model_provider=\"anthropic\", temperature=0\n",
")\n",
"# Returns a langchain_google_vertexai.ChatVertexAI instance.\n",
"gemini_15 = init_chat_model(\n",
" \"gemini-1.5-pro\", model_provider=\"google_vertexai\", temperature=0\n",
")\n",
"\n",
"# Since all model integrations implement the ChatModel interface, you can use them in the same way.\n",
"print(\"GPT-4o: \" + gpt_4o.invoke(\"what's your name\").content + \"\\n\")\n",
"print(\"Claude Opus: \" + claude_opus.invoke(\"what's your name\").content + \"\\n\")\n",
"print(\"Gemini 1.5: \" + gemini_15.invoke(\"what's your name\").content + \"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "fff9a4c8-b6ee-4a1a-8d3d-0ecaa312d4ed",
"metadata": {},
"source": [
"## Simple config example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75c25d39-bf47-4b51-a6c6-64d9c572bfd6",
"metadata": {},
"outputs": [],
"source": [
"user_config = {\n",
" \"model\": \"...user-specified...\",\n",
" \"model_provider\": \"...user-specified...\",\n",
" \"temperature\": 0,\n",
" \"max_tokens\": 1000,\n",
"}\n",
"\n",
"llm = init_chat_model(**user_config)\n",
"llm.invoke(\"what's your name\")"
]
},
{
"cell_type": "markdown",
"id": "f811f219-5e78-4b62-b495-915d52a22532",
"metadata": {},
"source": [
"## Inferring model provider\n",
"\n",
"For common and distinct model names `init_chat_model()` will attempt to infer the model provider. See the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.init_chat_model.html) for a full list of inference behavior. E.g. any model that starts with `gpt-3...` or `gpt-4...` will be inferred as using model provider `openai`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0378ccc6-95bc-4d50-be50-fccc193f0a71",
"metadata": {},
"outputs": [],
"source": [
"gpt_4o = init_chat_model(\"gpt-4o\", temperature=0)\n",
"claude_opus = init_chat_model(\"claude-3-opus-20240229\", temperature=0)\n",
"gemini_15 = init_chat_model(\"gemini-1.5-pro\", temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da07b5c0-d2e6-42e4-bfcd-2efcfaae6221",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -78,6 +78,7 @@ Chat Models are newer forms of language models that take messages in and output
- [How to: stream a response back](/docs/how_to/chat_streaming)
- [How to: track token usage](/docs/how_to/chat_token_usage_tracking)
- [How to: track response metadata across providers](/docs/how_to/response_metadata)
- [How to: let your end users choose their model](/docs/how_to/chat_models_universal_init/)
### LLMs

@@ -21,6 +21,7 @@ import warnings
from langchain_core._api import LangChainDeprecationWarning

from langchain._api.interactive_env import is_interactive_env
from langchain.chat_models.base import init_chat_model


def __getattr__(name: str) -> None:
@@ -41,6 +42,7 @@ def __getattr__(name: str) -> None:
__all__ = [
    "init_chat_model",
    "ChatOpenAI",
    "BedrockChat",
    "AzureChatOpenAI",

@@ -1,3 +1,7 @@
from importlib import util
from typing import Any, Optional

from langchain_core._api import beta
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    SimpleChatModel,
@@ -10,4 +14,189 @@ __all__ = [
"SimpleChatModel",
"generate_from_stream",
"agenerate_from_stream",
"init_chat_model",
]


# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
@beta()
def init_chat_model(
    model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
"""Initialize a ChatModel from the model name and provider.
Must have the integration package corresponding to the model provider installed.
Args:
model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
model_provider: The model provider. Supported model_provider values and the
corresponding integration package:
- openai (langchain-openai)
- anthropic (langchain-anthropic)
- azure_openai (langchain-openai)
- google_vertexai (langchain-google-vertexai)
- google_genai (langchain-google-genai)
- bedrock (langchain-aws)
- cohere (langchain-cohere)
- fireworks (langchain-fireworks)
- together (langchain-together)
- mistralai (langchain-mistralai)
- huggingface (langchain-huggingface)
- groq (langchain-groq)
- ollama (langchain-community)
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- gpt-3... or gpt-4... -> openai
- claude... -> anthropic
- amazon.... -> bedrock
- gemini... -> google_vertexai
- command... -> cohere
- accounts/fireworks... -> fireworks
kwargs: Additional keyword args to pass to
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
Returns:
The BaseChatModel corresponding to the model_name and model_provider specified.
Raises:
ValueError: If model_provider cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
Example:
.. code-block:: python
from langchain.chat_models import init_chat_model
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)
gpt_4o.invoke("what's your name")
claude_opus.invoke("what's your name")
gemini_15.invoke("what's your name")
""" # noqa: E501
    model_provider = model_provider or _attempt_infer_model_provider(model)
    if not model_provider:
        raise ValueError(
            f"Unable to infer model provider for {model=}, please specify "
            f"model_provider directly."
        )
    model_provider = model_provider.replace("-", "_").lower()
    if model_provider == "openai":
        _check_pkg("langchain_openai")
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model, **kwargs)
    elif model_provider == "anthropic":
        _check_pkg("langchain_anthropic")
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(model=model, **kwargs)
    elif model_provider == "azure_openai":
        _check_pkg("langchain_openai")
        from langchain_openai import AzureChatOpenAI

        return AzureChatOpenAI(model=model, **kwargs)
    elif model_provider == "cohere":
        _check_pkg("langchain_cohere")
        from langchain_cohere import ChatCohere

        return ChatCohere(model=model, **kwargs)
    elif model_provider == "google_vertexai":
        _check_pkg("langchain_google_vertexai")
        from langchain_google_vertexai import ChatVertexAI

        return ChatVertexAI(model=model, **kwargs)
    elif model_provider == "google_genai":
        _check_pkg("langchain_google_genai")
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(model=model, **kwargs)
    elif model_provider == "fireworks":
        _check_pkg("langchain_fireworks")
        from langchain_fireworks import ChatFireworks

        return ChatFireworks(model=model, **kwargs)
    elif model_provider == "ollama":
        _check_pkg("langchain_community")
        from langchain_community.chat_models import ChatOllama

        return ChatOllama(model=model, **kwargs)
    elif model_provider == "together":
        _check_pkg("langchain_together")
        from langchain_together import ChatTogether

        return ChatTogether(model=model, **kwargs)
    elif model_provider == "mistralai":
        _check_pkg("langchain_mistralai")
        from langchain_mistralai import ChatMistralAI

        return ChatMistralAI(model=model, **kwargs)
    elif model_provider == "huggingface":
        _check_pkg("langchain_huggingface")
        from langchain_huggingface import ChatHuggingFace

        return ChatHuggingFace(model_id=model, **kwargs)
    elif model_provider == "groq":
        _check_pkg("langchain_groq")
        from langchain_groq import ChatGroq

        return ChatGroq(model=model, **kwargs)
elif model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports
return ChatBedrock(model_id=model, **kwargs)
    else:
        supported = ", ".join(_SUPPORTED_PROVIDERS)
        raise ValueError(
            f"Unsupported {model_provider=}.\n\nSupported model providers are: "
            f"{supported}"
        )


_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
}


def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
    if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"):
        return "openai"
    elif model_name.startswith("claude"):
        return "anthropic"
    elif model_name.startswith("command"):
        return "cohere"
    elif model_name.startswith("accounts/fireworks"):
        return "fireworks"
    elif model_name.startswith("gemini"):
        return "google_vertexai"
    elif model_name.startswith("amazon."):
        return "bedrock"
    else:
        return None


def _check_pkg(pkg: str) -> None:
    if not util.find_spec(pkg):
        pkg_kebab = pkg.replace("_", "-")
        raise ImportError(
            f"Unable to import {pkg_kebab}. Please install with "
            f"`pip install -U {pkg_kebab}`"
        )
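
A short usage sketch of the inference and error paths above (the groq model name is borrowed from the tests further down):

from langchain.chat_models import init_chat_model

# Provider inferred from the "claude" prefix -> anthropic.
claude = init_chat_model("claude-3-opus-20240229", temperature=0)

# "mixtral-8x7b-32768" matches no known prefix, so the provider must be
# given explicitly; omitting it raises ValueError. If langchain-groq is
# not installed, _check_pkg raises ImportError naming the pip package.
mixtral = init_chat_model("mixtral-8x7b-32768", model_provider="groq")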

File diff suppressed because it is too large.

@@ -22,15 +22,11 @@ PyYAML = ">=5.3"
numpy = "^1"
aiohttp = "^3.8.3"
tenacity = "^8.1.0"
async-timeout = {version = "^4.0.0", python = "<3.11"}
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true}
manifest-ml = {version = "^0.0.1", optional = true}
transformers = {version = "^4", optional = true}
beautifulsoup4 = {version = "^4", optional = true}
torch = {version = ">=1,<3", optional = true}
jinja2 = {version = "^3", optional = true}
tiktoken = {version = ">=0.7,<1.0", optional = true, python=">=3.9"}
qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
cohere = {version = ">=4,<6", optional = true}
@@ -38,77 +34,28 @@ openai = {version = "<2", optional = true}
nlpcloud = {version = "^1", optional = true}
huggingface_hub = {version = "^0", optional = true}
sentence-transformers = {version = "^2", optional = true}
arxiv = {version = "^1.4", optional = true}
pypdf = {version = "^3.4.0", optional = true}
aleph-alpha-client = {version="^2.15.0", optional = true}
pgvector = {version = "^0.1.6", optional = true}
async-timeout = {version = "^4.0.0", python = "<3.11"}
azure-identity = {version = "^1.12.0", optional=true}
atlassian-python-api = {version = "^3.36.0", optional=true}
html2text = {version="^2020.1.16", optional=true}
numexpr = {version="^2.8.6", optional=true}
azure-cosmos = {version="^4.4.0b1", optional=true}
jq = {version = "^1.4.1", optional = true}
pdfminer-six = {version = "^20221105", optional = true}
docarray = {version="^0.32.0", extras=["hnswlib"], optional=true}
lxml = {version = ">=4.9.3,<6.0", optional = true}
pymupdf = {version = "^1.22.3", optional = true}
rapidocr-onnxruntime = {version = "^1.3.2", optional = true, python = ">=3.8.1,<3.12"}
pypdfium2 = {version = "^4.10.0", optional = true}
gql = {version = "^3.4.1", optional = true}
pandas = {version = "^2.0.1", optional = true}
telethon = {version = "^1.28.5", optional = true}
chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true}
openlm = {version = "^0.0.5", optional = true}
scikit-learn = {version = "^1.2.2", optional = true}
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
py-trello = {version = "^0.19.0", optional = true}
bibtexparser = {version = "^1.4.0", optional = true}
pyspark = {version = "^3.4.0", optional = true}
clarifai = {version = ">=9.1.0", optional = true}
mwparserfromhell = {version = "^0.6.4", optional = true}
mwxml = {version = "^0.3.3", optional = true}
azure-search-documents = {version = "11.4.0b8", optional = true}
esprima = {version = "^4.0.1", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true}
cassio = {version = "^0.1.0", optional = true}
sympy = {version = "^1.12", optional = true}
rapidfuzz = {version = "^3.1.1", optional = true}
jsonschema = {version = ">1", optional = true}
rank-bm25 = {version = "^0.2.2", optional = true}
geopandas = {version = "^0.13.1", optional = true}
gitpython = {version = "^3.1.32", optional = true}
feedparser = {version = "^6.0.10", optional = true}
newspaper3k = {version = "^0.2.8", optional = true}
xata = {version = "^1.0.0a7", optional = true}
xmltodict = {version = "^0.13.0", optional = true}
markdownify = {version = "^0.11.6", optional = true}
assemblyai = {version = "^0.17.0", optional = true}
dashvector = {version = "^1.0.1", optional = true}
sqlite-vss = {version = "^0.1.2", optional = true}
motor = {version = "^3.3.1", optional = true}
timescale-vector = {version = "^0.0.1", optional = true}
typer = {version= "^0.9.0", optional = true}
anthropic = {version = "^0.3.11", optional = true}
aiosqlite = {version = "^0.19.0", optional = true}
rspace_client = {version = "^2.5.0", optional = true}
upstash-redis = {version = "^0.15.0", optional = true}
azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.9.0", optional = true}
javelin-sdk = {version = "^0.1.8", optional = true}
hologres-vector = {version = "^0.0.6", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
couchbase = {version = "^4.1.9", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
langchain-openai = {version = "^0.1", optional = true}
rdflib = {version = "7.0.0", optional = true}
langchain-openai = {version = "^0", optional = true}
langchain-anthropic = {version = "^0", optional = true}
langchain-fireworks = {version = "^0", optional = true}
langchain-together = {version = "^0", optional = true}
langchain-mistralai = {version = "^0", optional = true}
langchain-groq = {version = "^0", optional = true}
jsonschema = {version = "^4.22.0", optional = true}
[tool.poetry.group.test]
optional = true
@@ -162,11 +109,8 @@ optional = true
# https://python.langchain.com/docs/contributing/code#working-with-optional-dependencies
pytest-vcr = "^1.0.2"
wrapt = "^1.15.0"
openai = "^1"
python-dotenv = "^1.0.0"
cassio = "^0.1.0"
tiktoken = ">=0.7,<1"
anthropic = "^0.3.11"
langchain-core = {path = "../core", develop = true}
langchain-text-splitters = {path = "../text-splitters", develop = true}
langchainhub = "^0.1.16"
@@ -229,74 +173,18 @@ cli = ["typer"]
# Please use new-line on formatting to make it easier to add new packages without
# merge-conflicts
extended_testing = [
"aleph-alpha-client",
"aiosqlite",
"assemblyai",
"beautifulsoup4",
"bibtexparser",
"cassio",
"chardet",
"datasets",
"google-cloud-documentai",
"esprima",
"jq",
"pdfminer-six",
"pgvector",
"pypdf",
"pymupdf",
"pypdfium2",
"tqdm",
"lxml",
"atlassian-python-api",
"mwparserfromhell",
"mwxml",
"msal",
"pandas",
"telethon",
"psychicapi",
"gql",
"requests-toolbelt",
"html2text",
"numexpr",
"py-trello",
"scikit-learn",
"streamlit",
"pyspark",
"openai",
"sympy",
"rapidfuzz",
"jsonschema",
"openai",
"rank-bm25",
"geopandas",
"jinja2",
"gitpython",
"newspaper3k",
"feedparser",
"xata",
"xmltodict",
"faiss-cpu",
"openapi-pydantic",
"markdownify",
"arxiv",
"dashvector",
"sqlite-vss",
"rapidocr-onnxruntime",
"motor",
"timescale-vector",
"anthropic",
"upstash-redis",
"rspace_client",
"fireworks-ai",
"javelin-sdk",
"hologres-vector",
"praw",
"databricks-vectorsearch",
"couchbase",
"dgml-utils",
"cohere",
"langchain-openai",
"rdflib",
"langchain-openai",
"langchain-anthropic",
"langchain-fireworks",
"langchain-together",
"langchain-mistralai",
"langchain-groq",
"openai",
"tiktoken",
"numexpr",
"rapidfuzz",
"aiosqlite",
"jsonschema",
]
[tool.ruff]

@@ -1,12 +1,47 @@
from langchain.chat_models.base import __all__
import pytest

from langchain.chat_models.base import __all__, init_chat_model

EXPECTED_ALL = [
    "BaseChatModel",
    "SimpleChatModel",
    "agenerate_from_stream",
    "generate_from_stream",
    "init_chat_model",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)


@pytest.mark.requires(
    "langchain_openai",
    "langchain_anthropic",
    "langchain_fireworks",
    "langchain_together",
    "langchain_mistralai",
    "langchain_groq",
)
@pytest.mark.parametrize(
    ["model_name", "model_provider"],
    [
        ("gpt-4o", "openai"),
        ("claude-3-opus-20240229", "anthropic"),
        ("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
        ("meta-llama/Llama-3-8b-chat-hf", "together"),
        ("mixtral-8x7b-32768", "groq"),
    ],
)
def test_init_chat_model(model_name: str, model_provider: str) -> None:
    init_chat_model(model_name, model_provider=model_provider, api_key="foo")


def test_init_missing_dep() -> None:
    with pytest.raises(ImportError):
        init_chat_model("gpt-4o", model_provider="openai")


def test_init_unknown_provider() -> None:
    with pytest.raises(ValueError):
        init_chat_model("foo", model_provider="bar")

@@ -1,6 +1,7 @@
from langchain import chat_models

EXPECTED_ALL = [
    "init_chat_model",
    "ChatOpenAI",
    "BedrockChat",
    "AzureChatOpenAI",
