mirror of https://github.com/hwchase17/langchain
Description
* Refactor Fireworks within LangChain LLMs.
* Remove FireworksChat from LangChain LLMs.
* Add ChatFireworks (which uses the chat completion API) to LangChain chat models.
* Users must install `fireworks-ai` and register an API key to use the API.

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin @baskaryan

Branch: pull/11117/head
parent 5514ebe859
commit 6dd44ff1c0
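A minimal end-to-end sketch of what this change enables, distilled from the notebook and module diffs below (assumes `langchain` and `fireworks-ai` are installed; `<your_api_key>` is a placeholder for a key registered with Fireworks AI):

import os

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import HumanMessage

# The integration reads the key from FIREWORKS_API_KEY
# (see validate_environment in the module diff below).
os.environ["FIREWORKS_API_KEY"] = "<your_api_key>"

# Defaults to accounts/fireworks/models/llama-v2-7b-chat if no model is given.
chat = ChatFireworks()
print(chat([HumanMessage(content="Hello")]).content)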
@@ -0,0 +1,255 @@ (new file: ChatFireworks example notebook)
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "642fd21c-600a-47a1-be96-6e1438b421a9",
   "metadata": {},
   "source": [
    "# ChatFireworks\n",
    "\n",
    ">[Fireworks](https://app.fireworks.ai/) accelerates product development on generative AI by creating an innovative AI experiment and production platform.\n",
    "\n",
    "This example goes over how to use LangChain to interact with `ChatFireworks` models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d00d850917865298",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain.chat_models.fireworks import ChatFireworks\n",
    "from langchain.schema import SystemMessage, HumanMessage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28ebf8b-f14f-46c7-9962-8b8dc42e31be",
   "metadata": {},
   "source": [
    "# Setup\n",
    "Contact Fireworks AI for an API key to access our models.\n",
    "\n",
    "Set up your model using a model id. If the model is not set, the default model is accounts/fireworks/models/llama-v2-7b-chat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d096fb14-8acc-4047-9cd0-c842430c3a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a Fireworks Chat model\n",
    "os.environ['FIREWORKS_API_KEY'] = \"<your_api_key>\"  # Change this to your own API key\n",
    "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f13144-37cf-47a5-b5a0-e3cdf76d9a72",
   "metadata": {},
   "source": [
    "# Calling the Model\n",
    "\n",
    "You can call the model with one or more messages.\n",
    "\n",
    "See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72340871-ae2f-415f-b399-0777d32dc379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ChatFireworks wrapper: send a system message and a human message\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"Who are you?\")\n",
    "response = chat([system_message, human_message])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2d6ef879-69e3-422b-8379-bb980b70fe55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist users with tasks and answer questions to the best of my ability. I am capable of understanding and responding to natural language input, and I am here to help you with any questions or tasks you may have. Is there anything specific you would like to know or discuss?\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting additional parameters: temperature, max_tokens, top_p\n",
    "chat = ChatFireworks(\n",
    "    model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
    "    model_kwargs={\"temperature\": 1, \"max_tokens\": 20, \"top_p\": 1},\n",
    ")\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"How's the weather today?\")\n",
    "response = chat([system_message, human_message])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a09025f8-e4c3-4005-a8fc-c9c774b03a64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Oh, you know, it's just another beautiful day in the virtual world! The sun\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8",
   "metadata": {},
   "source": [
    "# ChatFireworks Wrapper with generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = ChatFireworks()\n",
    "message = HumanMessage(content=\"Hello\")\n",
    "response = chat.generate([[message], [message]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "35109f36-9519-47a6-a223-25639123e836",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGeneration(text=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", additional_kwargs={}, example=False))], [ChatGeneration(text=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", additional_kwargs={}, example=False))]], llm_output={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, run=[RunInfo(run_id=UUID('f137463e-e1c7-454a-8b85-b999ce20e0f2')), RunInfo(run_id=UUID('f3ef1138-92de-4e01-900b-991e34a647a7'))])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92c2cabb-9eaf-4c49-b0e5-a5de5a7d920e",
   "metadata": {},
   "source": [
    "# ChatFireworks Wrapper with stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "12717a29-fb7d-4a4d-860b-40435452b065",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Hello! I'm just\n",
      " an AI assistant,\n",
      " here to help answer your\n",
      " questions and provide information in\n",
      " a responsible and respectful manner\n",
      ". I'm not able\n",
      " to access personal information or provide\n",
      " any content that could be considered\n",
      " harmful, uneth\n",
      "ical, racist, sex\n",
      "ist, toxic, dangerous\n",
      ", or illegal. My purpose\n",
      " is to assist and provide helpful\n",
      " responses that are socially un\n",
      "biased and positive in nature\n",
      ". Is there something specific you\n",
      " would like to know or discuss\n",
      "?\n"
     ]
    }
   ],
   "source": [
    "llm = ChatFireworks()\n",
    "\n",
    "for token in llm.stream(\"Who are you\"):\n",
    "    print(token.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02991e05-a38e-47d4-9ab3-7e630a8ead55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
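The notebook exercises the synchronous paths; the module below also implements async counterparts (`_agenerate`, `_astream`), exposed through the standard `agenerate`/`astream` methods. A sketch of how they might be driven (an illustration, assuming `FIREWORKS_API_KEY` is set):

import asyncio

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import HumanMessage


async def main() -> None:
    chat = ChatFireworks()
    # Async batch generation: one list of messages per prompt.
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)

    # Async token streaming.
    async for token in chat.astream("Who are you?"):
        print(token.content, end="")


asyncio.run(main())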
@@ -0,0 +1,264 @@ (new file: ChatFireworks chat model implementation)
import fireworks
import fireworks.client
from pydantic import root_validator
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Type,
    Union,
)

from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain.utils.env import get_from_dict_or_env


def _convert_delta_to_message_chunk(
    _dict: Any, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a streamed delta to a message chunk.

    Fields are read as attributes because the Fireworks client returns
    attribute-style response objects rather than plain dicts.
    """
    role = _dict.role
    content = _dict.content or ""
    additional_kwargs: Dict[str, Any] = {}

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict.name)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def convert_dict_to_message(_dict: Any) -> BaseMessage:
    """Convert a Fireworks response message to a LangChain message.

    Fields are read as attributes because the Fireworks client returns
    attribute-style response objects rather than plain dicts.
    """
    role = _dict.role
    content = _dict.content or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        return AIMessage(content=content, additional_kwargs={})
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "function":
        return FunctionMessage(content=content, name=_dict.name)
    else:
        return ChatMessage(content=content, role=role)


class ChatFireworks(BaseChatModel):
    """Fireworks Chat models."""

    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
    model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
    fireworks_api_key: Optional[str] = None
    max_retries: int = 20

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set, either on the model or via FIREWORKS_API_KEY."""
        fireworks_api_key = get_from_dict_or_env(
            values, "fireworks_api_key", "FIREWORKS_API_KEY"
        )
        fireworks.client.api_key = fireworks_api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)
        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = completion_with_retry(self, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = self._create_message_dicts(messages, stop)
        params = {
            "model": self.model,
            "messages": message_dicts,
            **self.model_kwargs,
        }
        response = await acompletion_with_retry(self, **params)
        return self._create_chat_result(response)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        # Every call reports the same model, so the first output is representative.
        return llm_outputs[0]

    def _create_chat_result(self, response: Any) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(res.message)
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> List[Dict[str, Any]]:
        # `stop` is accepted for interface compatibility but not yet forwarded
        # to the Fireworks API.
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        for chunk in completion_with_retry(self, **params):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages, stop)
        default_chunk_class = AIMessageChunk
        params = {
            "model": self.model,
            "messages": message_dicts,
            "stream": True,
            **self.model_kwargs,
        }
        async for chunk in await acompletion_with_retry_streaming(self, **params):
            choice = chunk.choices[0]
            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
            finish_reason = choice.finish_reason
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)


def completion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return fireworks.client.ChatCompletion.create(
            **kwargs,
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


async def acompletion_with_retry_streaming(
    llm: ChatFireworks,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call for streaming."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # `acreate` is returned without awaiting so that the caller receives
        # the async generator and can `async for` over it.
        return fireworks.client.ChatCompletion.acreate(
            **kwargs,
        )

    return await _completion_with_retry(**kwargs)


def _create_retry_decorator(
    llm: ChatFireworks,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Define the retry mechanism: retry on rate-limit and availability errors."""
    errors = [
        fireworks.client.error.RateLimitError,
        fireworks.client.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )
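One subtlety in the converters above: although they are conceptually dict-to-message helpers, they read `role`/`content` as attributes, because the `fireworks.client` response objects expose fields that way. A quick illustration using `types.SimpleNamespace` as a stand-in for a client response message (the stand-in is an assumption for demonstration only):

from types import SimpleNamespace

from langchain.chat_models.fireworks import convert_dict_to_message

# Mimic the attribute-style message object returned by fireworks.client.
fake_message = SimpleNamespace(role="assistant", content="Hi there!")

msg = convert_dict_to_message(fake_message)
print(type(msg).__name__, "-", msg.content)  # AIMessage - Hi there!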
@@ -0,0 +1,106 @@ (new file: ChatFireworks integration tests)
"""Test ChatFireworks wrapper."""

import pytest

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


def test_chat_fireworks() -> None:
    """Test ChatFireworks wrapper."""
    chat = ChatFireworks()
    message = HumanMessage(content="What is the weather in Redwood City, CA today")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_model() -> None:
    """Test ChatFireworks wrapper handles the model field."""
    chat = ChatFireworks(model="foo")
    assert chat.model == "foo"


def test_chat_fireworks_system_message() -> None:
    """Test ChatFireworks wrapper with a system message."""
    chat = ChatFireworks()
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_fireworks_generate() -> None:
    """Test ChatFireworks wrapper with generate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_fireworks_multiple_completions() -> None:
    """Test ChatFireworks wrapper with multiple completions."""
    chat = ChatFireworks(model_kwargs={"n": 5})
    message = HumanMessage(content="Hello")
    # Calls the private _generate directly to get a ChatResult back.
    response = chat._generate([message])
    assert isinstance(response, ChatResult)
    assert len(response.generations) == 5
    for generation in response.generations:
        assert isinstance(generation.message, BaseMessage)
        assert isinstance(generation.message.content, str)


def test_chat_fireworks_llm_output_contains_model_id() -> None:
    """Test that llm_output contains the model id."""
    chat = ChatFireworks()
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model"] == chat.model


def test_fireworks_streaming() -> None:
    """Test streaming tokens from Fireworks."""
    llm = ChatFireworks()

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
    """Test ChatFireworks wrapper with agenerate."""
    chat = ChatFireworks(model_kwargs={"n": 2})
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_fireworks_astream() -> None:
    """Test streaming tokens asynchronously from Fireworks."""
    llm = ChatFireworks()

    async for token in llm.astream("Who's the best quarterback in the NFL?"):
        assert isinstance(token.content, str)
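test_chat_fireworks_multiple_completions above reaches into the private `_generate` to obtain a `ChatResult` directly; the equivalent through the public API looks roughly like this (a sketch, assuming a configured API key):

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage

# Request two completions per prompt via model_kwargs.
chat = ChatFireworks(model_kwargs={"n": 2})
result = chat.generate([[HumanMessage(content="Hello")]])
assert len(result.generations[0]) == 2  # two candidate generations for the prompt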
File diff suppressed because it is too large