mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
langchain-mistralai
- make base URL possible to set via env variable for ChatMistralAI
(#25956)
Thank you for contributing to LangChain! **Description:** Similar to other packages (`langchain_openai`, `langchain_anthropic`) it would be beneficial if that `ChatMistralAI` model could fetch the API base URL from the environment. This PR allows this via the following order: - provided value - then whatever `MISTRAL_API_URL` is set to - then whatever `MISTRAL_BASE_URL` is set to - if `None`, then default is ` "https://api.mistral.com/v1"` - [x] **Add tests and docs**: Added unit tests, docs I feel are unnecessary, as this is just aligning with other packages that do the same? - [x] **Lint and test**: Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
c7154a4045
commit
fdeaff4149
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
@ -364,7 +365,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
|
||||
)
|
||||
endpoint: str = "https://api.mistral.ai/v1"
|
||||
endpoint: Optional[str] = Field(default=None, alias="base_url")
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
@ -472,10 +473,17 @@ class ChatMistralAI(BaseChatModel):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
|
||||
# todo: handle retries
|
||||
base_url_str = (
|
||||
values.get("endpoint")
|
||||
or os.environ.get("MISTRAL_BASE_URL")
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
values["endpoint"] = base_url_str
|
||||
if not values.get("client"):
|
||||
values["client"] = httpx.Client(
|
||||
base_url=values["endpoint"],
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
@ -486,7 +494,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
# todo: handle retries and max_concurrency
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
base_url=values["endpoint"],
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
|
@ -44,6 +44,39 @@ def test_mistralai_initialization() -> None:
|
||||
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected_url",
|
||||
[
|
||||
(ChatMistralAI(model="test"), "https://api.mistral.ai/v1"), # type: ignore[call-arg, arg-type]
|
||||
(ChatMistralAI(model="test", endpoint="baz"), "baz"), # type: ignore[call-arg, arg-type]
|
||||
],
|
||||
)
|
||||
def test_mistralai_initialization_baseurl(
|
||||
model: ChatMistralAI, expected_url: str
|
||||
) -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized providing endpoint, but also
|
||||
# with default
|
||||
|
||||
assert model.endpoint == expected_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_var_name",
|
||||
[
|
||||
("MISTRAL_BASE_URL"),
|
||||
],
|
||||
)
|
||||
def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized using env variable
|
||||
import os
|
||||
|
||||
os.environ[env_var_name] = "boo"
|
||||
model = ChatMistralAI(model="test") # type: ignore[call-arg]
|
||||
assert model.endpoint == "boo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "expected"),
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user