langchain-mistralai - make base URL possible to set via env variable for ChatMistralAI (#25956)

Thank you for contributing to LangChain!


**Description:** 

Similar to other packages (`langchain_openai`, `langchain_anthropic`) it
would be beneficial if that `ChatMistralAI` model could fetch the API
base URL from the environment.

This PR allows this via the following order:
- provided value
- then whatever `MISTRAL_API_URL` is set to
- then whatever `MISTRAL_BASE_URL` is set to
- if `None`, then default is ` "https://api.mistral.com/v1"`


- [x] **Add tests and docs**:

Added unit tests, docs I feel are unnecessary, as this is just aligning
with other packages that do the same?


- [x] **Lint and test**: 

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Maximilian Schulz 2024-09-03 16:32:35 +02:00 committed by GitHub
parent c7154a4045
commit fdeaff4149
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 3 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib
import json
import logging
import os
import re
import uuid
from operator import itemgetter
@ -364,7 +365,7 @@ class ChatMistralAI(BaseChatModel):
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
)
endpoint: str = "https://api.mistral.ai/v1"
endpoint: Optional[str] = Field(default=None, alias="base_url")
max_retries: int = 5
timeout: int = 120
max_concurrent_requests: int = 64
@ -472,10 +473,17 @@ class ChatMistralAI(BaseChatModel):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
base_url_str = (
values.get("endpoint")
or os.environ.get("MISTRAL_BASE_URL")
or "https://api.mistral.ai/v1"
)
values["endpoint"] = base_url_str
if not values.get("client"):
values["client"] = httpx.Client(
base_url=values["endpoint"],
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
@ -486,7 +494,7 @@ class ChatMistralAI(BaseChatModel):
# todo: handle retries and max_concurrency
if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",

View File

@ -44,6 +44,39 @@ def test_mistralai_initialization() -> None:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
@pytest.mark.parametrize(
"model,expected_url",
[
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]
],
)
def test_mistralai_initialization_baseurl(
model: ChatMistralAI, expected_url: str
) -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized providing endpoint, but also
# with default
assert model.endpoint == expected_url
@pytest.mark.parametrize(
"env_var_name",
[
("MISTRAL_BASE_URL"),
],
)
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using env variable
import os
os.environ[env_var_name] = "boo"
model = ChatMistralAI(model="test") # type: ignore[call-arg]
assert model.endpoint == "boo"
@pytest.mark.parametrize(
("message", "expected"),
[