mirror of https://github.com/hwchase17/langchain
Add fireworks chat model (#11117)
commit 8e4dbae428
@ -0,0 +1,255 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "642fd21c-600a-47a1-be96-6e1438b421a9",
   "metadata": {},
   "source": [
    "# ChatFireworks\n",
    "\n",
    ">[Fireworks](https://app.fireworks.ai/) accelerates product development on generative AI by creating an innovative AI experiment and production platform.\n",
    "\n",
    "This example goes over how to use LangChain to interact with `ChatFireworks` models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d00d850917865298",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain.chat_models.fireworks import ChatFireworks\n",
    "from langchain.schema import HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28ebf8b-f14f-46c7-9962-8b8dc42e31be",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "Contact Fireworks AI for an API key to access our models.\n",
    "\n",
    "Set up your model using a model id. If the model is not set, the default is `accounts/fireworks/models/llama-v2-7b-chat`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d096fb14-8acc-4047-9cd0-c842430c3a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a Fireworks chat model\n",
    "os.environ[\"FIREWORKS_API_KEY\"] = \"<your_api_key>\"  # Replace with your own API key\n",
    "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f13144-37cf-47a5-b5a0-e3cdf76d9a72",
   "metadata": {},
   "source": [
    "# Calling the Model\n",
    "\n",
    "You can call the model directly with one or more messages.\n",
    "\n",
    "See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72340871-ae2f-415f-b399-0777d32dc379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ChatFireworks Wrapper\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"Who are you?\")\n",
    "response = chat([system_message, human_message])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2d6ef879-69e3-422b-8379-bb980b70fe55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist users with tasks and answer questions to the best of my ability. I am capable of understanding and responding to natural language input, and I am here to help you with any questions or tasks you may have. Is there anything specific you would like to know or discuss?\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting additional parameters: temperature, max_tokens, top_p\n",
    "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\": 1, \"max_tokens\": 20, \"top_p\": 1})\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"How's the weather today?\")\n",
    "response = chat([system_message, human_message])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a09025f8-e4c3-4005-a8fc-c9c774b03a64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Oh, you know, it's just another beautiful day in the virtual world! The sun\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8",
   "metadata": {},
   "source": [
    "# ChatFireworks Wrapper with generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = ChatFireworks()\n",
    "message = HumanMessage(content=\"Hello\")\n",
    "response = chat.generate([[message], [message]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "35109f36-9519-47a6-a223-25639123e836",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGeneration(text=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", additional_kwargs={}, example=False))], [ChatGeneration(text=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", additional_kwargs={}, example=False))]], llm_output={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, run=[RunInfo(run_id=UUID('f137463e-e1c7-454a-8b85-b999ce20e0f2')), RunInfo(run_id=UUID('f3ef1138-92de-4e01-900b-991e34a647a7'))])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92c2cabb-9eaf-4c49-b0e5-a5de5a7d920e",
   "metadata": {},
   "source": [
    "# ChatFireworks Wrapper with stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "12717a29-fb7d-4a4d-860b-40435452b065",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Hello! I'm just\n",
      " an AI assistant,\n",
      " here to help answer your\n",
      " questions and provide information in\n",
      " a responsible and respectful manner\n",
      ". I'm not able\n",
      " to access personal information or provide\n",
      " any content that could be considered\n",
      " harmful, uneth\n",
      "ical, racist, sex\n",
      "ist, toxic, dangerous\n",
      ", or illegal. My purpose\n",
      " is to assist and provide helpful\n",
      " responses that are socially un\n",
      "biased and positive in nature\n",
      ". Is there something specific you\n",
      " would like to know or discuss\n",
      "?\n"
     ]
    }
   ],
   "source": [
    "llm = ChatFireworks()\n",
    "\n",
    "for token in llm.stream(\"Who are you\"):\n",
    "    print(token.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02991e05-a38e-47d4-9ab3-7e630a8ead55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
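The notebook covers synchronous calls and streaming; the `ChatFireworks` implementation below also provides `_agenerate` and `_astream`, so the standard async entry points work as well. A minimal sketch of the async side, assuming `FIREWORKS_API_KEY` is set in the environment (the prompts here are illustrative):

import asyncio

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import HumanMessage


async def main() -> None:
    chat = ChatFireworks()

    # agenerate mirrors generate, awaiting the Fireworks completion API.
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)

    # astream yields message chunks as they arrive from the stream.
    async for chunk in chat.astream("Who are you?"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())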
@ -0,0 +1,292 @@
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Type,
    Union,
)

from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain.utils.env import get_from_dict_or_env


def _convert_delta_to_message_chunk(
    _dict: Any, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a delta response to a message chunk."""
    role = _dict.role
    content = _dict.content or ""
    additional_kwargs: Dict = {}

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict.name)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def convert_dict_to_message(_dict: Any) -> BaseMessage:
    """Convert a dict response to a message."""
    role = _dict.role
    content = _dict.content or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        content = _dict.content
        additional_kwargs: Dict = {}
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "function":
        return FunctionMessage(content=content, name=_dict.name)
    else:
        return ChatMessage(content=content, role=role)


class ChatFireworks(BaseChatModel):
    """Fireworks Chat models."""

    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
    model_kwargs: dict = Field(
        default_factory=lambda: {
            "temperature": 0.7,
            "max_tokens": 512,
            "top_p": 1,
        }.copy()
    )
    fireworks_api_key: Optional[str] = None
    max_retries: int = 20

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set in the environment."""
        try:
            import fireworks.client
        except ImportError as e:
            raise ImportError(
                "Could not import fireworks-ai python package. "
                "Please install it with `pip install fireworks-ai`."
            ) from e
        fireworks_api_key = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        fireworks.client.api_key = fireworks_api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)

        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = completion_with_retry(self, run_manager=run_manager, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)
        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = await acompletion_with_retry(self, run_manager=run_manager, **params)
        return self._create_chat_result(response)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        if llm_outputs[0] is None:
            return {}
        return llm_outputs[0]

    def _create_chat_result(self, response: Any) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(res.message)
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> List[Dict[str, Any]]:
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        for chunk in completion_with_retry(self, run_manager=run_manager, **params):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        async for chunk in await acompletion_with_retry_streaming(
            self, run_manager=run_manager, **params
        ):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)


def completion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    import fireworks.client

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.create(
            **kwargs,
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    import fireworks.client

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call for streaming."""
    import fireworks.client

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


def _create_retry_decorator(
    llm: ChatFireworks,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Define retry mechanism."""
    import fireworks.client

    errors = [
        fireworks.client.error.RateLimitError,
        fireworks.client.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )
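One subtlety worth noting in review: despite their names, `convert_dict_to_message` and `_convert_delta_to_message_chunk` read `role` and `content` as attributes, so they expect the response objects returned by `fireworks.client`, not plain dicts. A quick illustration with a stand-in object (the `SimpleNamespace` values here are hypothetical, not real API output):

from types import SimpleNamespace

from langchain.chat_models.fireworks import (
    _convert_delta_to_message_chunk,
    convert_dict_to_message,
)
from langchain.schema.messages import AIMessageChunk

# Stand-in for the object fireworks.client returns for one choice's message.
fake_message = SimpleNamespace(role="assistant", content="Hi there")
print(convert_dict_to_message(fake_message))
# -> AIMessage(content='Hi there')

# Streaming deltas may omit the role; the previous chunk's class is reused.
fake_delta = SimpleNamespace(role=None, content=" more text")
print(_convert_delta_to_message_chunk(fake_delta, AIMessageChunk))
# -> AIMessageChunk(content=' more text')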
@ -1,377 +1,245 @@
 """Wrapper around Fireworks APIs"""
-import json
-import logging
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-)
-
-import requests
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
-from langchain.schema import Generation, LLMResult
-from langchain.utils import get_from_dict_or_env
-
-logger = logging.getLogger(__name__)
-
-
-class BaseFireworks(BaseLLM):
-    """Wrapper around Fireworks large language models."""
-
-    model_id: str = Field("accounts/fireworks/models/llama-v2-7b-chat", alias="model")
-    """Model name to use."""
-    temperature: float = 0.7
-    """What sampling temperature to use."""
-    max_tokens: int = 512
-    """The maximum number of tokens to generate in the completion.
-    -1 returns as many tokens as possible given the prompt and
-    the models maximal context size."""
-    top_p: float = 1
-    """Total probability mass of tokens to consider at each step."""
+from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.output import GenerationChunk
+from langchain.schema.runnable.config import RunnableConfig
+from langchain.utils.env import get_from_dict_or_env
+
+
+def _stream_response_to_generation_chunk(
+    stream_response: Any,
+) -> GenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    return GenerationChunk(
+        text=stream_response.choices[0].text,
+        generation_info=dict(
+            finish_reason=stream_response.choices[0].finish_reason,
+            logprobs=stream_response.choices[0].logprobs,
+        ),
+    )
+
+
+class Fireworks(LLM):
+    """Fireworks models."""
+
+    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
+    model_kwargs: dict = Field(
+        default_factory=lambda: {
+            "temperature": 0.7,
+            "max_tokens": 512,
+            "top_p": 1,
+        }.copy()
+    )
     fireworks_api_key: Optional[str] = None
-    """Api key to use fireworks API"""
-    batch_size: int = 20
-    """Batch size to use when passing multiple documents to generate."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
-    """Timeout for requests to Fireworks completion API. Default is 600 seconds."""
-    max_retries: int = 6
-    """Maximum number of retries to make when generating."""
-
-    @property
-    def lc_secrets(self) -> Dict[str, str]:
-        return {"fireworks_api_key": "FIREWORKS_API_KEY"}
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return True
-
-    def __new__(cls, **data: Any) -> Any:
-        """Initialize the Fireworks object."""
-        data.get("model_id", "")
-        return super().__new__(cls)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        allow_population_by_field_name = True
+    max_retries: int = 20
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        values["fireworks_api_key"] = get_from_dict_or_env(
+        """Validate that the API key is set in the environment."""
+        try:
+            import fireworks.client
+        except ImportError as e:
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
+        fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
+        fireworks.client.api_key = fireworks_api_key
         return values
 
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Call out to Fireworks endpoint with k unique prompts.
-        Args:
-            prompts: The prompts to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            The full LLM output.
-        """
-        params = {"model": self.model_id}
-        params = {**params, **kwargs}
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
-        choices = []
-        token_usage: Dict[str, int] = {}
-        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
-        for _prompts in sub_prompts:
-            response = completion_with_retry(self, prompt=prompts, **params)
-            choices.extend(response)
-            update_token_usage(_keys, response, token_usage)
-
-        return self.create_llm_result(choices, prompts, token_usage)
-
-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Call out to Fireworks endpoint async with k unique prompts."""
-        params = {"model": self.model_id}
-        params = {**params, **kwargs}
-        sub_prompts = self.get_batch_prompts(params, prompts, stop)
-        choices = []
-        token_usage: Dict[str, int] = {}
-        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
-        for _prompts in sub_prompts:
-            response = await acompletion_with_retry(self, prompt=_prompts, **params)
-            choices.extend(response)
-            update_token_usage(_keys, response, token_usage)
-
-        return self.create_llm_result(choices, prompts, token_usage)
-
-    def get_batch_prompts(
-        self,
-        params: Dict[str, Any],
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-    ) -> List[List[str]]:
-        """Get the sub prompts for llm call."""
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-
-        sub_prompts = [
-            prompts[i : i + self.batch_size]
-            for i in range(0, len(prompts), self.batch_size)
-        ]
-        return sub_prompts
-
-    def create_llm_result(
-        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
-    ) -> LLMResult:
-        """Create the LLMResult from the choices and prompts."""
-        generations = []
-
-        for i, _ in enumerate(prompts):
-            sub_choices = choices[i : (i + 1)]
-            generations.append(
-                [
-                    Generation(
-                        text=choice,
-                    )
-                    for choice in sub_choices
-                ]
-            )
-        llm_output = {"token_usage": token_usage, "model_id": self.model_id}
-        return LLMResult(generations=generations, llm_output=llm_output)
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "fireworks"
-
-
-class FireworksChat(BaseLLM):
-    """Wrapper around Fireworks Chat large language models.
-    To use, you should have the ``fireworksai`` python package installed, and the
-    environment variable ``FIREWORKS_API_KEY`` set with your API key.
-    Any parameters that are valid to be passed to the fireworks.create
-    call can be passed in, even if not explicitly saved on this class.
-    Example:
-        .. code-block:: python
-            from langchain.llms import FireworksChat
-            fireworkschat = FireworksChat(model_id="llama-v2-13b-chat")
-    """
-
-    model_id: str = "accounts/fireworks/models/llama-v2-7b-chat"
-    """Model name to use."""
-    temperature: float = 0.7
-    """What sampling temperature to use."""
-    max_tokens: int = 512
-    """The maximum number of tokens to generate in the completion.
-    -1 returns as many tokens as possible given the prompt and
-    the models maximal context size."""
-    top_p: float = 1
-    """Total probability mass of tokens to consider at each step."""
-    fireworks_api_key: Optional[str] = None
-    max_retries: int = 6
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
-    """Timeout for requests to Fireworks completion API. Default is 600 seconds."""
-    """Maximum number of retries to make when generating."""
-    prefix_messages: List = Field(default_factory=list)
-    """Series of messages for Chat input."""
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment"""
-        values["fireworks_api_key"] = get_from_dict_or_env(
-            values, "fireworks_api_key", "FIREWORKS_API_KEY"
-        )
-        return values
-
-    def _get_chat_params(
-        self, prompts: List[str], stop: Optional[List[str]] = None
-    ) -> Tuple:
-        if len(prompts) > 1:
-            raise ValueError(
-                f"FireworksChat currently only supports single prompt, got {prompts}"
-            )
-        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
-        params: Dict[str, Any] = {**{"model": self.model_id}}
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-
-        return messages, params
-
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        messages, params = self._get_chat_params(prompts, stop)
-        params = {**params, **kwargs}
-        full_response = completion_with_retry(self, messages=messages, **params)
-        llm_output = {
-            "model_id": self.model_id,
-        }
-        return LLMResult(
-            generations=[[Generation(text=full_response[0])]],
-            llm_output=llm_output,
-        )
-
-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        messages, params = self._get_chat_params(prompts, stop)
-        params = {**params, **kwargs}
-        full_response = await acompletion_with_retry(self, messages=messages, **params)
-        llm_output = {
-            "model_id": self.model_id,
-        }
-        return LLMResult(
-            generations=[[Generation(text=full_response[0])]],
-            llm_output=llm_output,
-        )
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "fireworks-chat"
-
-
-class Fireworks(BaseFireworks):
-    """Wrapper around Fireworks large language models.
-    To use, you should have the ``fireworks`` python package installed, and the
-    environment variable ``FIREWORKS_API_KEY`` set with your API key.
-    Any parameters that are valid to be passed to the fireworks.create
-    call can be passed in, even if not explicitly saved on this class.
-    Example:
-        .. code-block:: python
-            from langchain.llms import fireworks
-            llm = Fireworks(model_id="llama-v2-13b")
-    """
-
-
-def update_token_usage(
-    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
-) -> None:
-    """Update token usage."""
-    _keys_to_use = keys.intersection(response)
-    for _key in _keys_to_use:
-        if _key not in token_usage:
-            token_usage[_key] = response["usage"][_key]
-        else:
-            token_usage[_key] += response["usage"][_key]
-
-
-def execute(
-    prompt: str,
-    model: str,
-    api_key: Optional[str],
-    max_tokens: int = 256,
-    temperature: float = 0.0,
-    top_p: float = 1.0,
-) -> Any:
-    """Execute LLM query"""
-    requestUrl = "https://api.fireworks.ai/inference/v1/completions"
-    requestBody = {
-        "model": model,
-        "prompt": prompt,
-        "max_tokens": max_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-    }
-    requestHeaders = {
-        "Authorization": f"Bearer {api_key}",
-        "Accept": "application/json",
-        "Content-Type": "application/json",
-    }
-    response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
-    return response.text
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "fireworks"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Run the LLM on the given prompt and input."""
+        params: dict = {
+            "model": self.model,
+            "prompt": prompt,
+            **self.model_kwargs,
+        }
+        response = completion_with_retry(self, run_manager=run_manager, **params)
+
+        return response.choices[0].text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Run the LLM on the given prompt and input."""
+        params = {
+            "model": self.model,
+            "prompt": prompt,
+            **self.model_kwargs,
+        }
+        response = await acompletion_with_retry(self, run_manager=run_manager, **params)
+
+        return response.choices[0].text
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        params = {
+            "model": self.model,
+            "prompt": prompt,
+            "stream": True,
+            **self.model_kwargs,
+        }
+        for stream_resp in completion_with_retry(
+            self, run_manager=run_manager, **params
+        ):
+            chunk = _stream_response_to_generation_chunk(stream_resp)
+            yield chunk
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        params = {
+            "model": self.model,
+            "prompt": prompt,
+            "stream": True,
+            **self.model_kwargs,
+        }
+        async for stream_resp in await acompletion_with_retry_streaming(
+            self, run_manager=run_manager, **params
+        ):
+            chunk = _stream_response_to_generation_chunk(stream_resp)
+            yield chunk
+
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        prompt = self._convert_input(input).to_string()
+        generation: Optional[GenerationChunk] = None
+        for chunk in self._stream(prompt):
+            yield chunk.text
+            if generation is None:
+                generation = chunk
+            else:
+                generation += chunk
+        assert generation is not None
+
+    async def astream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        prompt = self._convert_input(input).to_string()
+        generation: Optional[GenerationChunk] = None
+        async for chunk in self._astream(prompt):
+            yield chunk.text
+            if generation is None:
+                generation = chunk
+            else:
+                generation += chunk
+        assert generation is not None
 
 
 def completion_with_retry(
-    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
+    llm: Fireworks,
+    *,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    if "prompt" not in kwargs.keys():
-        answers = []
-        for i in range(len(kwargs["messages"])):
-            result = kwargs["messages"][i]["content"]
-            result = execute(
-                result,
-                kwargs["model"],
-                llm.fireworks_api_key,
-                llm.max_tokens,
-                llm.temperature,
-                llm.top_p,
-            )
-            curr_string = json.loads(result)["choices"][0]["text"]
-            answers.append(curr_string)
-    else:
-        answers = []
-        for i in range(len(kwargs["prompt"])):
-            result = kwargs["prompt"][i]
-            result = execute(
-                result,
-                kwargs["model"],
-                llm.fireworks_api_key,
-                llm.max_tokens,
-                llm.temperature,
-                llm.top_p,
-            )
-            curr_string = json.loads(result)["choices"][0]["text"]
-            answers.append(curr_string)
-    return answers
+    import fireworks.client
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @retry_decorator
+    def _completion_with_retry(**kwargs: Any) -> Any:
+        return fireworks.client.Completion.create(
+            **kwargs,
+        )
+
+    return _completion_with_retry(**kwargs)
 
 
 async def acompletion_with_retry(
-    llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
+    llm: Fireworks,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
-    """Use tenacity to retry the async completion call."""
-    if "prompt" not in kwargs.keys():
-        answers = []
-        for i in range(len(kwargs["messages"])):
-            result = kwargs["messages"][i]["content"]
-            result = execute(
-                result,
-                kwargs["model"],
-                llm.fireworks_api_key,
-                llm.max_tokens,
-                llm.temperature,
-            )
-            curr_string = json.loads(result)["choices"][0]["text"]
-            answers.append(curr_string)
-    else:
-        answers = []
-        for i in range(len(kwargs["prompt"])):
-            result = kwargs["prompt"][i]
-            result = execute(
-                result,
-                kwargs["model"],
-                llm.fireworks_api_key,
-                llm.max_tokens,
-                llm.temperature,
-            )
-            curr_string = json.loads(result)["choices"][0]["text"]
-            answers.append(curr_string)
-    return answers
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @retry_decorator
+    async def _completion_with_retry(**kwargs: Any) -> Any:
+        return await fireworks.client.Completion.acreate(
+            **kwargs,
+        )
+
+    return await _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry_streaming(
+    llm: Fireworks,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call for streaming."""
+    import fireworks.client
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @retry_decorator
+    async def _completion_with_retry(**kwargs: Any) -> Any:
+        return fireworks.client.Completion.acreate(
+            **kwargs,
+        )
+
+    return await _completion_with_retry(**kwargs)
+
+
+def _create_retry_decorator(
+    llm: Fireworks,
+    *,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Define retry mechanism."""
+    import fireworks.client
+
+    errors = [
+        fireworks.client.error.RateLimitError,
+        fireworks.client.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )
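On the `+` side, `Fireworks` now takes a single prompt string per call and overrides `stream`/`astream` to yield plain strings assembled from `GenerationChunk`s. A minimal usage sketch; the model id and prompts are illustrative, and it assumes `pip install fireworks-ai` and `FIREWORKS_API_KEY` in the environment:

from langchain.llms.fireworks import Fireworks

llm = Fireworks(
    model="accounts/fireworks/models/llama-v2-13b",
    model_kwargs={"temperature": 0.7, "max_tokens": 64, "top_p": 1},
)

# _call is reached through the standard LLM __call__ interface.
print(llm("Tell me a joke about otters."))

# The overridden stream() yields each chunk's text as a plain string.
for piece in llm.stream("Name three tropical fruits."):
    print(piece, end="", flush=True)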
@ -0,0 +1,105 @@
"""Test ChatFireworks wrapper."""

import pytest

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


def test_chat_fireworks() -> None:
    """Test ChatFireworks wrapper."""
    chat = ChatFireworks()
    message = HumanMessage(content="What is the weather in Redwood City, CA today")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_model() -> None:
    """Test ChatFireworks wrapper handles model_name."""
    chat = ChatFireworks(model="foo")
    assert chat.model == "foo"


def test_chat_fireworks_system_message() -> None:
    """Test ChatFireworks wrapper with system message."""
    chat = ChatFireworks()
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_generate() -> None:
    """Test ChatFireworks wrapper with generate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_fireworks_multiple_completions() -> None:
    """Test ChatFireworks wrapper with multiple completions."""
    chat = ChatFireworks(model_kwargs={"n": 5})
    message = HumanMessage(content="Hello")
    response = chat._generate([message])
    assert isinstance(response, ChatResult)
    assert len(response.generations) == 5
    for generation in response.generations:
        assert isinstance(generation.message, BaseMessage)
        assert isinstance(generation.message.content, str)


def test_chat_fireworks_llm_output_contains_model_id() -> None:
    """Test llm_output contains model_id."""
    chat = ChatFireworks()
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model"] == chat.model


def test_fireworks_streaming() -> None:
    """Test streaming tokens from Fireworks."""
    llm = ChatFireworks()

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
    """Test ChatFireworks wrapper with agenerate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_fireworks_astream() -> None:
    """Test streaming tokens from Fireworks."""
    llm = ChatFireworks()

    async for token in llm.astream("Who's the best quarterback in the NFL?"):
        assert isinstance(token.content, str)