multiple: standard chat model tests (#20359)

This commit is contained in:
Erick Friis 2024-04-11 18:23:13 -07:00 committed by GitHub
parent f78564d75c
commit e6806a08d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 237 additions and 11 deletions

View File

@ -278,7 +278,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.1.37" version = "0.1.42"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -291,7 +291,6 @@ langsmith = "^0.1.0"
packaging = "^23.2" packaging = "^23.2"
pydantic = ">=1,<3" pydantic = ">=1,<3"
PyYAML = ">=5.3" PyYAML = ">=5.3"
requests = "^2"
tenacity = "^8.1.0" tenacity = "^8.1.0"
[package.extras] [package.extras]
@ -301,6 +300,23 @@ extended-testing = ["jinja2 (>=3,<4)"]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
[[package]]
name = "langchain-standard-tests"
version = "0.1.0"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
langchain-core = "^0.1.40"
pytest = ">=7,<9"
[package.source]
type = "directory"
url = "../../standard-tests"
[[package]] [[package]]
name = "langchain-text-splitters" name = "langchain-text-splitters"
version = "0.0.1" version = "0.0.1"
@ -994,4 +1010,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "d0e6ec94729c40ea458eead2404d2b501f18dd67c211c832146352410223276e" content-hash = "446d53423f89a3a378db9cff7ce8cb392e146e294d31e9d1bbfc23108a571097"

View File

@ -22,6 +22,7 @@ syrupy = "^4.0.2"
pytest-watcher = "^0.3.4" pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true } langchain-core = { path = "../../core", develop = true }
langchain-standard-tests = {path = "../../standard-tests", develop = true}
[tool.poetry.group.codespell] [tool.poetry.group.codespell]
optional = true optional = true

View File

@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ai21 import ChatAI21
class TestAI21Standard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
}

View File

@ -0,0 +1,22 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_ai21 import ChatAI21
class TestAI21Standard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
"api_key": "test_api_key",
}

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -460,6 +460,23 @@ extended-testing = ["jinja2 (>=3,<4)"]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
[[package]]
name = "langchain-standard-tests"
version = "0.1.0"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
langchain-core = "^0.1.40"
pytest = ">=7,<9"
[package.source]
type = "directory"
url = "../../standard-tests"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.42" version = "0.1.42"
@ -1206,4 +1223,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "dce6101bb01cd9ab521a8b8df906db7b4631d8305952a442c78db9a77f2b2f1f" content-hash = "adc0024beed52fca8f0f9e1786aaf5c25b3bf6fd138fb9668b44ef36c6bcf23f"

View File

@ -28,6 +28,7 @@ pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true } langchain-core = { path = "../../core", develop = true }
defusedxml = "^0.7.1" defusedxml = "^0.7.1"
langchain-standard-tests = {path = "../../standard-tests", develop = true}
[tool.poetry.group.codespell] [tool.poetry.group.codespell]
optional = true optional = true

View File

@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAnthropic
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "claude-3-haiku-20240307",
}

View File

@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAnthropic
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "claude-3-haiku-20240307",
}

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
@ -594,6 +594,23 @@ extended-testing = ["jinja2 (>=3,<4)"]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
[[package]]
name = "langchain-standard-tests"
version = "0.1.0"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
langchain-core = "^0.1.40"
pytest = ">=7,<9"
[package.source]
type = "directory"
url = "../../standard-tests"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.10" version = "0.1.10"
@ -1536,4 +1553,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "5f4c474fcbc2ef84c95a0e4992c0acdb08e4fa5817909576d31caf793e54109b" content-hash = "fbf305613a6134e08c9efec406928b30ba7830a13a87c9a523708699b7efc9a3"

View File

@ -29,6 +29,7 @@ syrupy = "^4.0.2"
pytest-watcher = "^0.3.4" pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true } langchain-core = { path = "../../core", develop = true }
langchain-standard-tests = {path = "../../standard-tests", develop = true}
[tool.poetry.group.codespell] [tool.poetry.group.codespell]
optional = true optional = true

View File

@ -0,0 +1,15 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatFireworks

View File

@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatFireworks
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"api_key": "test_api_key",
}

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -345,6 +345,23 @@ extended-testing = ["jinja2 (>=3,<4)"]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
[[package]]
name = "langchain-standard-tests"
version = "0.1.0"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
langchain-core = "^0.1.40"
pytest = ">=7,<9"
[package.source]
type = "directory"
url = "../../standard-tests"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.4" version = "0.1.4"
@ -867,4 +884,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "f76983ba51b9fc343a68563a9f2d847efe9c5856bbeff5796d9146c521221a53" content-hash = "1692a375c2817216876453275294e5aa2500364b7e36ae2b4b0ec1fe1837402e"

View File

@ -24,6 +24,7 @@ pytest-mock = "^3.10.0"
pytest-watcher = "^0.3.4" pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true } langchain-core = { path = "../../core", develop = true }
langchain-standard-tests = {path = "../../standard-tests", develop = true}
[tool.poetry.group.codespell] [tool.poetry.group.codespell]
optional = true optional = true

View File

@ -0,0 +1,15 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_groq import ChatGroq
class TestMistralStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq

View File

@ -0,0 +1,15 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_groq import ChatGroq
class TestGroqStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq

View File

@ -38,7 +38,10 @@ class ChatModelUnitTests(ABC):
def chat_model_has_structured_output( def chat_model_has_structured_output(
self, chat_model_class: Type[BaseChatModel] self, chat_model_class: Type[BaseChatModel]
) -> bool: ) -> bool:
return hasattr(chat_model_class, "with_structured_output") return (
chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
def test_chat_model_init( def test_chat_model_init(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
@ -49,7 +52,8 @@ class ChatModelUnitTests(ABC):
def test_chat_model_init_api_key( def test_chat_model_init_api_key(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None: ) -> None:
model = chat_model_class(api_key="test", **chat_model_params) # type: ignore params = {**chat_model_params, "api_key": "test"}
model = chat_model_class(**params) # type: ignore
assert model is not None assert model is not None
def test_chat_model_init_streaming( def test_chat_model_init_streaming(