Fireworks integration (#8322)

Description - Integrates Fireworks into the LangChain LLMs module so that users
can call Fireworks models from LangChain, mainly for summarization.

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin

---------

Co-authored-by: Raj Janardhan <rajjanardhan@Rajs-Laptop.attlocal.net>
Committed by rjanardhan3 on 2023-08-01 21:17:26 -07:00 (via GitHub)
Commit 68113348cc, parent b574507c51
5 changed files with 790 additions and 0 deletions


@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cc6caafa",
"metadata": {},
"source": [
"# Fireworks\n",
"\n",
">[Fireworks](https://www.fireworks.ai/) is an AI startup focused on accelerating product development on generative AI by creating an innovative AI experiment and production platform. \n",
"\n",
"This example goes over how to use LangChain to interact with `Fireworks` models."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "60b6dbb2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.fireworks import Fireworks, FireworksChat\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
")\n",
"import os"
]
},
{
"cell_type": "markdown",
"id": "ccff689e",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"Contact Fireworks AI for the an API Key to access our models\n",
"\n",
"Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-13b-chat."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9ca87a2e",
"metadata": {},
"outputs": [],
"source": [
"# Initialize a Fireworks LLM\n",
"os.environ['FIREWORKS_API_KEY'] = \"\" #change this to your own API KEY\n",
"llm = Fireworks(model_id=\"fireworks-llama-v2-13b-chat\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "43a11ba8",
"metadata": {},
"outputs": [],
"source": [
"# Create LLM chain\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "markdown",
"id": "acc24d0c",
"metadata": {},
"source": [
"# Calling the Model\n",
"\n",
"You can use the LLMs to call the model for specified prompt(s). \n",
"\n",
"Current Specified Models: \n",
"* fireworks-falcon-7b, fireworks-falcon-40b-w8a16\n",
"* fireworks-guanaco-30b, fireworks-guanaco-33b-w8a16\n",
"* fireworks-llama-7b, fireworks-llama-13b, fireworks-llama-30b-w8a16\n",
"* fireworks-llama-v2-13b, fireworks-llama-v2-13b-chat, fireworks-llama-v2-13b-w8a16, fireworks-llama-v2-13b-chat-w8a16\n",
"* fireworks-llama-v2-7b, fireworks-llama-v2-7b-chat, fireworks-llama-v2-7b-w8a16, fireworks-llama-v2-7b-chat-w8a16"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bf0a425c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"It's a question that has been debated for years, with different analysts and fans making their cases for various signal-callers. Here are some of the top contenders for the title of best quarterback in the NFL:\n",
"\n",
"1. Tom Brady: The New England Patriots legend has won six Super Bowls and has been named Super Bowl MVP four times. He's known for his precision passing, pocket presence, and ability to read defenses.\n",
"2. Aaron Rodgers: The Green Bay Packers quarterback has won two Super Bowls and has been named NFL MVP twice. He's known for his quick release, accuracy, and ability to extend plays with his feet.\n",
"3. Drew Brees: The New Orleans Saints quarterback has won a Super Bowl and has been named NFL MVP once. He's known for his accuracy, pocket presence, and ability to read defenses.\n",
"4. Patrick Mahomes: The Kansas City Chiefs quarterback has won a Super Bowl and has been named NFL MVP twice. He's known for his arm strength, athleticism, and ability to make plays outside of the pocket.\n",
"5. Russell Wilson: The Seattle Seahawks quarterback has won a Super Bowl and has been named NFL MVP once. He's known for his mobility, accuracy, and ability to extend plays with his feet.\n",
"\n",
"Of course, there are other talented quarterbacks in the league, such as Lamar Jackson, Deshaun Watson, and Carson Wentz, who could also be considered among the best. Ultimately, the answer to the question of who's the best quarterback in the NFL is subjective and can vary depending on individual perspectives and criteria.\n"
]
}
],
"source": [
"#single prompt\n",
"output = llm(\"Who's the best quarterback in the NFL?\")\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "afc7de6f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"generations=[[Generation(text=\"\\nWho is the best cricket player in the world in 2016?\\nThe best cricket player in the world in 2016 is Virat Kohli. The Indian captain has had a fabulous year, scoring heavily in all formats of the game, leading India to several victories, and breaking several records. In Test cricket, Kohli has scored 1215 runs at an average of 75.33 with 6 centuries and 4 fifties, which is the highest number of runs scored by any player in a calendar year. In ODI cricket, he has scored 1143 runs at an average of 83.42 with 7 centuries and 6 fifties, which is also the highest number of runs scored by any player in a calendar year. Additionally, Kohli has led India to the number one ranking in Test cricket, and has been named the ICC Test Player of the Year and the ICC ODI Player of the Year.\\nVirat Kohli has been in incredible form in 2016, and his performances have made him the standout player of the year. Other players who have had a great year include Steve Smith, Joe Root, and Kane Williamson, but Kohli's consistency and dominance in all formats of the game make him the best cricket player in the world in 2016.\", generation_info=None)], [Generation(text=\"\\n\\nA: LeBron James.\\n\\nB: Kevin Durant.\\n\\nC: Steph Curry.\\n\\nD: James Harden.\\n\\nE: Other (please specify).\\n\\nWhat's your answer?\", generation_info=None)]] llm_output={'token_usage': {}, 'model_id': 'fireworks-llama-v2-13b-chat'} run=[RunInfo(run_id=UUID('d14b6bee-7692-46ad-8798-acb6f72fc7fb')), RunInfo(run_id=UUID('b9f5b3b5-9e62-4eaf-b269-ecf0cbbcfb82'))]\n"
]
}
],
"source": [
"#calling multiple prompts\n",
"output = llm.generate([\"Who's the best cricket player in 2016?\", \"Who's the best basketball player in the league?\"])\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b801c20d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Kansas City in December can be quite chilly, with average high\n"
]
}
],
"source": [
"#setting parameters: model_id, temperature, max_tokens, top_p\n",
"llm = Fireworks(model_id=\"fireworks-llama-v2-13b-chat\", temperature=0.7, max_tokens=15, top_p=1.0)\n",
"print(llm(\"What's the weather like in Kansas City in December?\"))"
]
},
{
"cell_type": "markdown",
"id": "137662a6",
"metadata": {},
"source": [
"# Create and Run Chain\n",
"\n",
"Create a prompt template to be used with the LLM Chain. Once this prompt template is created, initialize the chain with the LLM and prompt template, and run the chain with the specified prompts."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "fd2c6bc1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"(Note: I'm just an AI and not a branding expert, so take this as a starting point for your own research and brainstorming.)\n",
"A good name for a company that makes football helmets could be:\n",
"\n",
"1. Helix Pro: This name plays off the idea of a helix, or spiral, shape that is commonly associated with football helmets. \"Pro\" implies a professional-level product.\n",
"2. Gridiron Gear: \"Gridiron\" is a term used to describe a football field, and \"gear\" highlights the company's focus on producing high-quality football helmets.\n",
"3. Linebacker Lab: \"Linebacker\" is a position on the football field, and \"Lab\" suggests a focus on research and development.\n",
"4. Helmet Hut: This name is simple and easy to remember, and it immediately conveys the company's focus on football helmets.\n",
"5. Tackle Tech: \"Tackle\" is a term used in football to describe a hit or collision, and \"Tech\" implies a focus on advanced technology and innovation.\n",
"6. Victory Vest: \"Victory\" implies a focus on winning and success, and \"Vest\" could suggest a protective or armored design.\n",
"7. Pigskin Pro: \"Pigskin\" is a term used to describe a football, and \"Pro\" implies a professional-level product.\n",
"8. Football Fusion: This name could suggest a combination of different materials or technologies to create a high-quality football helmet.\n",
"9. Endzone Edge: \"Endzone\" is the area of the football field where a team scores a touchdown, and \"Edge\" implies a competitive advantage.\n",
"10. MVP Masks: \"MVP\" stands for \"Most Valuable Player,\" and \"Masks\" highlights the protective nature of the company's football helmets.\n",
"\n",
"Remember, the name you choose for your company should be memorable, easy to pronounce and spell, and convey a sense of quality and professionalism. It's also important to check that the name isn't already in use by another company, and to consider any potential trademark issues.\n"
]
}
],
"source": [
"human_message_prompt = HumanMessagePromptTemplate(\n",
" prompt=PromptTemplate(\n",
" template=\"What is a good name for a company that makes {product}?\",\n",
" input_variables=[\"product\"],\n",
" )\n",
")\n",
"\n",
"chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])\n",
"chat = Fireworks()\n",
"chain = LLMChain(llm=chat, prompt=chat_prompt_template)\n",
"output = chain.run(\"football helmets\")\n",
"\n",
"print(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,22 @@
# Fireworks
This page covers how to use Fireworks models within LangChain.
## Installation and Setup
- To use the Fireworks model, you need to have a Fireworks API key. To generate one, sign up at platform.fireworks.ai
- Authenticate by setting the FIREWORKS_API_KEY environment variable.
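A minimal sketch of setting the key from Python (the value shown is a placeholder, not a real key):
```python
import os

os.environ["FIREWORKS_API_KEY"] = "<your-fireworks-api-key>"  # placeholder; use your own key
```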
## LLM
Fireworks integrates with LangChain through the LLM module, which provides a standardized interface to any model deployed on the Fireworks platform.
In this example, we'll work with the llama-v2-13b-chat model.
```python
from langchain.llms.fireworks import Fireworks
llm = Fireworks(model="fireworks-llama-v2-13b-chat", max_tokens=256, temperature=0.4)
llm("Name 3 sports.")
```
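## FireworksChat
The integration also includes a `FireworksChat` wrapper for the chat-tuned models. A minimal usage sketch based on the class docstring, assuming the same API key setup as above:
```python
from langchain.llms.fireworks import FireworksChat

chat_llm = FireworksChat(model_id="fireworks-llama-v2-13b-chat", max_tokens=256)
chat_llm("Name 3 sports.")
```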
For a more detailed walkthrough, see [here](/docs/extras/modules/model_io/models/llms/integrations/Fireworks.ipynb).


@@ -39,6 +39,7 @@ from langchain.llms.ctransformers import CTransformers
from langchain.llms.databricks import Databricks
from langchain.llms.deepinfra import DeepInfra
from langchain.llms.fake import FakeListLLM
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI
@@ -98,6 +99,8 @@ __all__ = [
"Databricks",
"DeepInfra",
"FakeListLLM",
"Fireworks",
"FireworksChat",
"ForefrontAI",
"GPT4All",
"GooglePalm",


@@ -0,0 +1,377 @@
"""Wrapper around Fireworks APIs"""
import json
import logging
from typing import (
Any,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
import requests
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class BaseFireworks(BaseLLM):
"""Wrapper around Fireworks large language models."""
model_id: str = Field("fireworks-llama-v2-7b-chat", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
fireworks_api_key: Optional[str] = None
"""Api key to use fireworks API"""
batch_size: int = 20
"""Batch size to use when passing multiple documents to generate."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
def __new__(cls, **data: Any) -> Any:
"""Initialize the Fireworks object."""
data.get("model_id", "")
return super().__new__(cls)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["fireworks_api_key"] = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
return values
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint with k unique prompts.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The full LLM output.
"""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = completion_with_retry(self, prompt=_prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint async with k unique prompts."""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = await acompletion_with_retry(self, prompt=_prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
def get_batch_prompts(
self,
params: Dict[str, Any],
prompts: List[str],
stop: Optional[List[str]] = None,
) -> List[List[str]]:
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
sub_prompts = [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
def create_llm_result(
self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
generations = []
for i, _ in enumerate(prompts):
sub_choices = choices[i : (i + 1)]
generations.append(
[
Generation(
text=choice,
)
for choice in sub_choices
]
)
llm_output = {"token_usage": token_usage, "model_id": self.model_id}
return LLMResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks"
class FireworksChat(BaseLLM):
"""Wrapper around Fireworks Chat large language models.
To use, you should have the environment variable ``FIREWORKS_API_KEY``
set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import FireworksChat
fireworkschat = FireworksChat(model_id="fireworks-llama-v2-13b-chat")
"""
model_id: str = "fireworks-llama-v2-7b-chat"
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
fireworks_api_key: Optional[str] = None
max_retries: int = 6
"""Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
prefix_messages: List = Field(default_factory=list)
"""Series of messages for Chat input."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment"""
values["fireworks_api_key"] = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
return values
def _get_chat_params(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> Tuple:
if len(prompts) > 1:
raise ValueError(
f"FireworksChat currently only supports single prompt, got {prompts}"
)
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
params: Dict[str, Any] = {**{"model": self.model_id}}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
return messages, params
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
class Fireworks(BaseFireworks):
"""Wrapper around Fireworks large language models.
To use, you should have the environment variable ``FIREWORKS_API_KEY``
set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import Fireworks
llm = Fireworks(model_id="fireworks-llama-v2-13b")
"""
def update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response)
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
def execute(
prompt: str,
model: str,
api_key: Optional[str],
max_tokens: int = 256,
temperature: float = 0.0,
top_p: float = 1.0,
) -> Any:
"""Execute LLM query"""
requestUrl = "https://api.fireworks.ai/inference/v1/completions"
requestBody = {
"model": model,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
requestHeaders = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
return response.text
def completion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
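# Completion-style call: each prompt in the list is sent as a separate request and the generated texts are collected.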
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers
async def acompletion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
"""Use tenacity to retry the async completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers


@@ -0,0 +1,157 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path
import pytest
from langchain import LLMChain, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake
def test_fireworks_call() -> None:
"""Test valid call to fireworks."""
llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=900)
output = llm("What is the weather in NYC")
assert isinstance(output, str)
def test_fireworks_in_chain() -> None:
"""Tests fireworks AI in a Langchain chain"""
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template="What is a good name for a company that makes {product}?",
input_variables=["product"],
)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
chat = Fireworks()
chain = LLMChain(llm=chat, prompt=chat_prompt_template)
output = chain.run("football helmets")
assert isinstance(output, str)
@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
"""Test async chat."""
llm = OpenAIChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
def test_fireworks_model_param() -> None:
"""Tests model parameters for Fireworks"""
llm = Fireworks(model="foo")
assert llm.model_id == "foo"
llm = Fireworks(model_id="foo")
assert llm.model_id == "foo"
def test_fireworkschat_model_param() -> None:
"""Tests model parameters for FireworksChat"""
llm = FireworksChat(model="foo")
assert llm.model_id == "foo"
llm = FireworksChat(model_id="foo")
assert llm.model_id == "foo"
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an Fireworks LLM."""
llm = Fireworks(max_tokens=10)
llm.save(file_path=tmp_path / "fireworks.yaml")
loaded_llm = load_llm(tmp_path / "fireworks.yaml")
assert loaded_llm == llm
def test_fireworks_multiple_prompts() -> None:
"""Test completion with multiple prompts."""
llm = Fireworks()
output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2
def test_fireworks_chat() -> None:
"""Test FireworksChat."""
llm = FireworksChat()
output = llm("Name me 3 quick facts about the New England Patriots")
assert isinstance(output, str)
@pytest.mark.asyncio
async def test_fireworks_agenerate() -> None:
llm = Fireworks()
output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2
@pytest.mark.asyncio
async def test_fireworkschat_agenerate() -> None:
llm = FireworksChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 1
def test_fireworkschat_chain() -> None:
embeddings = OpenAIEmbeddings()
loader = TextLoader(
"[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
db = DeepLake(
dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
)
db.add_documents(docs)
query = "What did the president say about Ketanji Brown Jackson"
docs = db.similarity_search(query)
qa = RetrievalQA.from_chain_type(
llm=FireworksChat(),
chain_type="stuff",
retriever=db.as_retriever(),
)
query = "What did the president say about Ketanji Brown Jackson"
output = qa.run(query)
assert isinstance(output, str)
_EXPECTED_NUM_TOKENS = {
"fireworks-llama-v2-13b": 17,
"fireworks-llama-v2-7b": 17,
"fireworks-llama-v2-13b-chat": 17,
"fireworks-llama-v2-7b-chat": 17,
}
_MODELS = [
"fireworks-llama-v2-13b",
"fireworks-llama-v2-7b",
"fireworks-llama-v2-13b-chat",
"fireworks-llama-v2-7b-chat",
]
@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
"""Test get_tokens."""
llm = Fireworks(model=model)
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]