"""Test Friendli LLM for chat."""
|
||
|
|
||
|
from unittest.mock import AsyncMock, MagicMock, Mock
|
||
|
|
||
|
import pytest
|
||
|
from langchain_core.pydantic_v1 import SecretStr
|
||
|
from pytest import CaptureFixture, MonkeyPatch
|
||
|
|
||
|
from langchain_community.adapters.openai import aenumerate
|
||
|
from langchain_community.chat_models import ChatFriendli
|
||
|
|
||
|
|
||
|
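# The fixtures below stand in for the friendli SDK's synchronous and
# asynchronous clients, so none of the tests in this module reach the
# Friendli API.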
@pytest.fixture
def mock_friendli_client() -> Mock:
    """Mock instance of Friendli client."""
    return Mock()


@pytest.fixture
def mock_friendli_async_client() -> AsyncMock:
    """Mock instance of Friendli async client."""
    return AsyncMock()


@pytest.fixture
def chat_friendli(
    mock_friendli_client: Mock, mock_friendli_async_client: AsyncMock
) -> ChatFriendli:
    """Friendli LLM for chat with mock clients."""
    return ChatFriendli(
        friendli_token=SecretStr("personal-access-token"),
        client=mock_friendli_client,
        async_client=mock_friendli_async_client,
    )


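# friendli_token is stored as a SecretStr, so printing the field should emit
# only the "**********" mask, never the raw token value.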
@pytest.mark.requires("friendli")
def test_friendli_token_is_secret_string(capsys: CaptureFixture) -> None:
    """Test if friendli token is stored as a SecretStr."""
    fake_token_value = "personal-access-token"
    chat = ChatFriendli(friendli_token=fake_token_value)
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


@pytest.mark.requires("friendli")
def test_friendli_token_read_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Test if friendli token can be parsed from environment."""
    fake_token_value = "personal-access-token"
    monkeypatch.setenv("FRIENDLI_TOKEN", fake_token_value)
    chat = ChatFriendli()
    assert isinstance(chat.friendli_token, SecretStr)
    assert chat.friendli_token.get_secret_value() == fake_token_value
    print(chat.friendli_token, end="")  # noqa: T201
    captured = capsys.readouterr()
    assert captured.out == "**********"


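# The invoke tests fake the SDK response with the choices[0].message.content
# shape that ChatFriendli reads back into the returned message, so response
# parsing and the forwarded request arguments can be asserted without a real
# completion call.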
@pytest.mark.requires("friendli")
def test_friendli_invoke(
    mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
    """Test invocation with friendli."""
    mock_message = Mock()
    mock_message.content = "Hello Friendli"
    mock_message.role = "assistant"
    mock_choice = Mock()
    mock_choice.message = mock_message
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_client.chat.completions.create.return_value = mock_response

    result = chat_friendli.invoke("Hello langchain")
    assert result.content == "Hello Friendli"
    mock_friendli_client.chat.completions.create.assert_called_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=False,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


@pytest.mark.requires("friendli")
async def test_friendli_ainvoke(
    mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
    """Test async invocation with friendli."""
    mock_message = Mock()
    mock_message.content = "Hello Friendli"
    mock_message.role = "assistant"
    mock_choice = Mock()
    mock_choice.message = mock_message
    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_friendli_async_client.chat.completions.create.return_value = mock_response

    result = await chat_friendli.ainvoke("Hello langchain")
    assert result.content == "Hello Friendli"
    mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=False,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


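# The streaming tests fake the SDK's chunk iterator: MagicMock provides
# __iter__ for the sync stream, AsyncMock provides __aiter__ for the async
# one, and each chunk carries choices[0].delta.content.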
@pytest.mark.requires("friendli")
def test_friendli_stream(
    mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
    """Test stream with friendli."""
    mock_delta_0 = Mock()
    mock_delta_0.content = "Hello "
    mock_delta_1 = Mock()
    mock_delta_1.content = "Friendli"
    mock_choice_0 = Mock()
    mock_choice_0.delta = mock_delta_0
    mock_choice_1 = Mock()
    mock_choice_1.delta = mock_delta_1
    mock_chunk_0 = Mock()
    mock_chunk_0.choices = [mock_choice_0]
    mock_chunk_1 = Mock()
    mock_chunk_1.choices = [mock_choice_1]
    mock_stream = MagicMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__iter__.return_value = mock_chunks

    mock_friendli_client.chat.completions.create.return_value = mock_stream
    stream = chat_friendli.stream("Hello langchain")
    for i, chunk in enumerate(stream):
        assert chunk.content == mock_chunks[i].choices[0].delta.content

    mock_friendli_client.chat.completions.create.assert_called_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=True,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )


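# async for has no built-in enumerate, so the async stream test indexes
# chunks with aenumerate from langchain_community.adapters.openai.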
@pytest.mark.requires("friendli")
async def test_friendli_astream(
    mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
    """Test async stream with friendli."""
    mock_delta_0 = Mock()
    mock_delta_0.content = "Hello "
    mock_delta_1 = Mock()
    mock_delta_1.content = "Friendli"
    mock_choice_0 = Mock()
    mock_choice_0.delta = mock_delta_0
    mock_choice_1 = Mock()
    mock_choice_1.delta = mock_delta_1
    mock_chunk_0 = Mock()
    mock_chunk_0.choices = [mock_choice_0]
    mock_chunk_1 = Mock()
    mock_chunk_1.choices = [mock_choice_1]
    mock_stream = AsyncMock()
    mock_chunks = [mock_chunk_0, mock_chunk_1]
    mock_stream.__aiter__.return_value = mock_chunks

    mock_friendli_async_client.chat.completions.create.return_value = mock_stream
    stream = chat_friendli.astream("Hello langchain")
    async for i, chunk in aenumerate(stream):
        assert chunk.content == mock_chunks[i].choices[0].delta.content

    mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
        messages=[{"role": "user", "content": "Hello langchain"}],
        stream=True,
        model=chat_friendli.model,
        frequency_penalty=None,
        presence_penalty=None,
        max_tokens=None,
        stop=None,
        temperature=None,
        top_p=None,
    )