|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
"""Test ChatFireworks wrapper."""
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
@ -6,25 +7,34 @@ from langchain.chat_models.fireworks import ChatFireworks
|
|
|
|
|
from langchain.schema import ChatGeneration, ChatResult, LLMResult
|
|
|
|
|
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
|
|
|
|
|
|
|
|
|
|
# The fireworks-ai SDK only supports Python 3.9+, so skip this whole test
# module (not just individual tests) on older interpreters.
if sys.version_info < (3, 9):
    pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
|
|
|
|
|
|
|
|
|
|
def test_chat_fireworks() -> None:
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def chat() -> ChatFireworks:
    """Shared ChatFireworks client: deterministic output (temperature 0), capped at 512 tokens."""
    return ChatFireworks(model_kwargs={"temperature": 0, "max_tokens": 512})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_chat_fireworks(chat: ChatFireworks) -> None:
    """Test ChatFireworks wrapper.

    Diff-residue fix: the leftover pre-patch line ``chat = ChatFireworks()``
    shadowed the injected ``chat`` fixture, silently discarding the fixture's
    temperature/max_tokens settings. The function now uses the fixture only.
    """
    message = HumanMessage(content="What is the weather in Redwood City, CA today")
    response = chat([message])
    # The callable interface returns a single message whose content is text.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_chat_fireworks_model() -> None:
    """Test ChatFireworks wrapper handles model_name."""
    # The ``model`` kwarg must round-trip verbatim onto the instance.
    expected_model = "foo"
    llm = ChatFireworks(model=expected_model)
    assert llm.model == expected_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_chat_fireworks_system_message(chat: ChatFireworks) -> None:
    """Test ChatFireworks wrapper with system message.

    Diff-residue fix: both the pre-patch ``def`` (no fixture) and the
    post-patch ``def`` (fixture-injected ``chat``) were present, along with a
    leftover ``chat = ChatFireworks()`` that shadowed the fixture. Merged to
    the fixture-based form, consistent with ``test_chat_fireworks``.
    """
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    # NOTE(review): one context line here was lost to a hunk boundary;
    # the BaseMessage check mirrors the sibling test — confirm against history.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
def test_chat_fireworks_generate() -> None:
|
|
|
|
|
"""Test ChatFireworks wrapper with generate."""
|
|
|
|
|
chat = ChatFireworks(model_kwargs={"n": 2})
|
|
|
|
@ -47,6 +58,7 @@ def test_chat_fireworks_generate() -> None:
|
|
|
|
|
assert generation.text == generation.message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
def test_chat_fireworks_multiple_completions() -> None:
|
|
|
|
|
"""Test ChatFireworks wrapper with multiple completions."""
|
|
|
|
|
chat = ChatFireworks(model_kwargs={"n": 5})
|
|
|
|
@ -59,35 +71,35 @@ def test_chat_fireworks_multiple_completions() -> None:
|
|
|
|
|
assert isinstance(generation.message.content, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None:
    """Test llm_output contains model_id.

    Diff-residue fix: removed the duplicate pre-patch ``def`` line and the
    leftover ``chat = ChatFireworks()`` that shadowed the injected fixture.
    """
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    # generate() must echo back which model produced the result.
    assert llm_result.llm_output["model"] == chat.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_fireworks_invoke(chat: ChatFireworks) -> None:
    """Tests chat completion with invoke.

    Diff-residue fix: removed the duplicate pre-patch ``def`` line and the
    leftover ``chat = ChatFireworks()`` that shadowed the injected fixture.
    """
    result = chat.invoke("How is the weather in New York today?", stop=[","])
    assert isinstance(result.content, str)
    # With a stop sequence of "," generation must end exactly at that token.
    assert result.content[-1] == ","
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_fireworks_ainvoke(chat: ChatFireworks) -> None:
    """Tests chat completion with ainvoke (async).

    Diff-residue fix: removed the duplicate pre-patch ``def`` line and the
    leftover ``chat = ChatFireworks()`` that shadowed the injected fixture.
    """
    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
    assert isinstance(result.content, str)
    # With a stop sequence of "," generation must end exactly at that token.
    assert result.content[-1] == ","
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_fireworks_batch() -> None:
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
def test_fireworks_batch(chat: ChatFireworks) -> None:
|
|
|
|
|
"""Test batch tokens from ChatFireworks."""
|
|
|
|
|
chat = ChatFireworks()
|
|
|
|
|
result = chat.batch(
|
|
|
|
|
[
|
|
|
|
|
"What is the weather in Redwood City, CA today",
|
|
|
|
@ -105,10 +117,10 @@ def test_fireworks_batch() -> None:
|
|
|
|
|
assert token.content[-1] == ","
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_fireworks_abatch() -> None:
|
|
|
|
|
async def test_fireworks_abatch(chat: ChatFireworks) -> None:
|
|
|
|
|
"""Test batch tokens from ChatFireworks."""
|
|
|
|
|
chat = ChatFireworks()
|
|
|
|
|
result = await chat.abatch(
|
|
|
|
|
[
|
|
|
|
|
"What is the weather in Redwood City, CA today",
|
|
|
|
@ -126,25 +138,26 @@ async def test_fireworks_abatch() -> None:
|
|
|
|
|
assert token.content[-1] == ","
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_fireworks_streaming(chat: ChatFireworks) -> None:
    """Test streaming tokens from Fireworks.

    Diff-residue fix: removed the duplicate pre-patch ``def``, the leftover
    ``llm = ChatFireworks()`` local, and the pre-patch ``llm.stream`` loop;
    the fixture-injected ``chat`` is streamed instead.
    """
    # Every streamed chunk must carry string content.
    for token in chat.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
def test_fireworks_streaming_stop_words(chat: ChatFireworks) -> None:
    """Test streaming tokens with stop words.

    Diff-residue fix: removed the duplicate pre-patch ``def``, the leftover
    ``llm = ChatFireworks()`` local, and the pre-patch ``llm.stream`` loop;
    the fixture-injected ``chat`` is streamed instead.
    """
    last_token = ""
    for token in chat.stream("I'm Pickle Rick", stop=[","]):
        last_token = token.content
        assert isinstance(token.content, str)
    # The final streamed chunk must end exactly at the "," stop sequence.
    assert last_token[-1] == ","
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_chat_fireworks_agenerate() -> None:
|
|
|
|
|
"""Test ChatFireworks wrapper with generate."""
|
|
|
|
@ -161,13 +174,13 @@ async def test_chat_fireworks_agenerate() -> None:
|
|
|
|
|
assert generation.text == generation.message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_fireworks_astream() -> None:
|
|
|
|
|
async def test_fireworks_astream(chat: ChatFireworks) -> None:
|
|
|
|
|
"""Test streaming tokens from Fireworks."""
|
|
|
|
|
llm = ChatFireworks()
|
|
|
|
|
|
|
|
|
|
last_token = ""
|
|
|
|
|
async for token in llm.astream(
|
|
|
|
|
async for token in chat.astream(
|
|
|
|
|
"Who's the best quarterback in the NFL?", stop=[","]
|
|
|
|
|
):
|
|
|
|
|
last_token = token.content
|
|
|
|
|