mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Fireworks integration (#8322)
Description - Integrates Fireworks within Langchain LLMs to allow users to use Fireworks models with Langchain, mainly for summarization. Issue - Not applicable Dependencies - None Tag maintainer - @rlancemartin --------- Co-authored-by: Raj Janardhan <rajjanardhan@Rajs-Laptop.attlocal.net>
This commit is contained in:
parent
b574507c51
commit
68113348cc
231
docs/extras/integrations/llms/Fireworks.ipynb
Normal file
231
docs/extras/integrations/llms/Fireworks.ipynb
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "cc6caafa",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Fireworks\n",
|
||||||
|
"\n",
|
||||||
|
">[Fireworks](https://www.fireworks.ai/) is an AI startup focused on accelerating product development on generative AI by creating an innovative AI experiment and production platform. \n",
|
||||||
|
"\n",
|
||||||
|
"This example goes over how to use LangChain to interact with `Fireworks` models."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"id": "60b6dbb2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms.fireworks import Fireworks, FireworksChat\n",
|
||||||
|
"from langchain import PromptTemplate, LLMChain\n",
|
||||||
|
"from langchain.prompts.chat import (\n",
|
||||||
|
" ChatPromptTemplate,\n",
|
||||||
|
" HumanMessagePromptTemplate,\n",
|
||||||
|
")\n",
|
||||||
|
"import os"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ccff689e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Setup\n",
|
||||||
|
"\n",
|
||||||
|
"Contact Fireworks AI for the an API Key to access our models\n",
|
||||||
|
"\n",
|
||||||
|
"Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-13b-chat."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "9ca87a2e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Initialize a Fireworks LLM\n",
|
||||||
|
"os.environ['FIREWORKS_API_KEY'] = \"\" #change this to your own API KEY\n",
|
||||||
|
"llm = Fireworks(model_id=\"fireworks-llama-v2-13b-chat\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"id": "43a11ba8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create LLM chain\n",
|
||||||
|
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "acc24d0c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Calling the Model\n",
|
||||||
|
"\n",
|
||||||
|
"You can use the LLMs to call the model for specified prompt(s). \n",
|
||||||
|
"\n",
|
||||||
|
"Current Specified Models: \n",
|
||||||
|
"* fireworks-falcon-7b, fireworks-falcon-40b-w8a16\n",
|
||||||
|
"* fireworks-guanaco-30b, fireworks-guanaco-33b-w8a16\n",
|
||||||
|
"* fireworks-llama-7b, fireworks-llama-13b, fireworks-llama-30b-w8a16\n",
|
||||||
|
"* fireworks-llama-v2-13b, fireworks-llama-v2-13b-chat, fireworks-llama-v2-13b-w8a16, fireworks-llama-v2-13b-chat-w8a16\n",
|
||||||
|
"* fireworks-llama-v2-7b, fireworks-llama-v2-7b-chat, fireworks-llama-v2-7b-w8a16, fireworks-llama-v2-7b-chat-w8a16"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "bf0a425c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"It's a question that has been debated for years, with different analysts and fans making their cases for various signal-callers. Here are some of the top contenders for the title of best quarterback in the NFL:\n",
|
||||||
|
"\n",
|
||||||
|
"1. Tom Brady: The New England Patriots legend has won six Super Bowls and has been named Super Bowl MVP four times. He's known for his precision passing, pocket presence, and ability to read defenses.\n",
|
||||||
|
"2. Aaron Rodgers: The Green Bay Packers quarterback has won two Super Bowls and has been named NFL MVP twice. He's known for his quick release, accuracy, and ability to extend plays with his feet.\n",
|
||||||
|
"3. Drew Brees: The New Orleans Saints quarterback has won a Super Bowl and has been named NFL MVP once. He's known for his accuracy, pocket presence, and ability to read defenses.\n",
|
||||||
|
"4. Patrick Mahomes: The Kansas City Chiefs quarterback has won a Super Bowl and has been named NFL MVP twice. He's known for his arm strength, athleticism, and ability to make plays outside of the pocket.\n",
|
||||||
|
"5. Russell Wilson: The Seattle Seahawks quarterback has won a Super Bowl and has been named NFL MVP once. He's known for his mobility, accuracy, and ability to extend plays with his feet.\n",
|
||||||
|
"\n",
|
||||||
|
"Of course, there are other talented quarterbacks in the league, such as Lamar Jackson, Deshaun Watson, and Carson Wentz, who could also be considered among the best. Ultimately, the answer to the question of who's the best quarterback in the NFL is subjective and can vary depending on individual perspectives and criteria.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"#single prompt\n",
|
||||||
|
"output = llm(\"Who's the best quarterback in the NFL?\")\n",
|
||||||
|
"print(output)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "afc7de6f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"generations=[[Generation(text=\"\\nWho is the best cricket player in the world in 2016?\\nThe best cricket player in the world in 2016 is Virat Kohli. The Indian captain has had a fabulous year, scoring heavily in all formats of the game, leading India to several victories, and breaking several records. In Test cricket, Kohli has scored 1215 runs at an average of 75.33 with 6 centuries and 4 fifties, which is the highest number of runs scored by any player in a calendar year. In ODI cricket, he has scored 1143 runs at an average of 83.42 with 7 centuries and 6 fifties, which is also the highest number of runs scored by any player in a calendar year. Additionally, Kohli has led India to the number one ranking in Test cricket, and has been named the ICC Test Player of the Year and the ICC ODI Player of the Year.\\nVirat Kohli has been in incredible form in 2016, and his performances have made him the standout player of the year. Other players who have had a great year include Steve Smith, Joe Root, and Kane Williamson, but Kohli's consistency and dominance in all formats of the game make him the best cricket player in the world in 2016.\", generation_info=None)], [Generation(text=\"\\n\\nA: LeBron James.\\n\\nB: Kevin Durant.\\n\\nC: Steph Curry.\\n\\nD: James Harden.\\n\\nE: Other (please specify).\\n\\nWhat's your answer?\", generation_info=None)]] llm_output={'token_usage': {}, 'model_id': 'fireworks-llama-v2-13b-chat'} run=[RunInfo(run_id=UUID('d14b6bee-7692-46ad-8798-acb6f72fc7fb')), RunInfo(run_id=UUID('b9f5b3b5-9e62-4eaf-b269-ecf0cbbcfb82'))]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"#calling multiple prompts\n",
|
||||||
|
"output = llm.generate([\"Who's the best cricket player in 2016?\", \"Who's the best basketball player in the league?\"])\n",
|
||||||
|
"print(output)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"id": "b801c20d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"Kansas City in December can be quite chilly, with average high\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"#setting parameters: model_id, temperature, max_tokens, top_p\n",
|
||||||
|
"llm = Fireworks(model_id=\"fireworks-llama-v2-13b-chat\", temperature=0.7, max_tokens=15, top_p=1.0)\n",
|
||||||
|
"print(llm(\"What's the weather like in Kansas City in December?\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "137662a6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Create and Run Chain\n",
|
||||||
|
"\n",
|
||||||
|
"Create a prompt template to be used with the LLM Chain. Once this prompt template is created, initialize the chain with the LLM and prompt template, and run the chain with the specified prompts."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"id": "fd2c6bc1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"(Note: I'm just an AI and not a branding expert, so take this as a starting point for your own research and brainstorming.)\n",
|
||||||
|
"A good name for a company that makes football helmets could be:\n",
|
||||||
|
"\n",
|
||||||
|
"1. Helix Pro: This name plays off the idea of a helix, or spiral, shape that is commonly associated with football helmets. \"Pro\" implies a professional-level product.\n",
|
||||||
|
"2. Gridiron Gear: \"Gridiron\" is a term used to describe a football field, and \"gear\" highlights the company's focus on producing high-quality football helmets.\n",
|
||||||
|
"3. Linebacker Lab: \"Linebacker\" is a position on the football field, and \"Lab\" suggests a focus on research and development.\n",
|
||||||
|
"4. Helmet Hut: This name is simple and easy to remember, and it immediately conveys the company's focus on football helmets.\n",
|
||||||
|
"5. Tackle Tech: \"Tackle\" is a term used in football to describe a hit or collision, and \"Tech\" implies a focus on advanced technology and innovation.\n",
|
||||||
|
"6. Victory Vest: \"Victory\" implies a focus on winning and success, and \"Vest\" could suggest a protective or armored design.\n",
|
||||||
|
"7. Pigskin Pro: \"Pigskin\" is a term used to describe a football, and \"Pro\" implies a professional-level product.\n",
|
||||||
|
"8. Football Fusion: This name could suggest a combination of different materials or technologies to create a high-quality football helmet.\n",
|
||||||
|
"9. Endzone Edge: \"Endzone\" is the area of the football field where a team scores a touchdown, and \"Edge\" implies a competitive advantage.\n",
|
||||||
|
"10. MVP Masks: \"MVP\" stands for \"Most Valuable Player,\" and \"Masks\" highlights the protective nature of the company's football helmets.\n",
|
||||||
|
"\n",
|
||||||
|
"Remember, the name you choose for your company should be memorable, easy to pronounce and spell, and convey a sense of quality and professionalism. It's also important to check that the name isn't already in use by another company, and to consider any potential trademark issues.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"human_message_prompt = HumanMessagePromptTemplate(\n",
|
||||||
|
" prompt=PromptTemplate(\n",
|
||||||
|
" template=\"What is a good name for a company that makes {product}?\",\n",
|
||||||
|
" input_variables=[\"product\"],\n",
|
||||||
|
" )\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])\n",
|
||||||
|
"chat = Fireworks()\n",
|
||||||
|
"chain = LLMChain(llm=chat, prompt=chat_prompt_template)\n",
|
||||||
|
"output = chain.run(\"football helmets\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(output)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
22
docs/extras/integrations/providers/fireworks.md
Normal file
22
docs/extras/integrations/providers/fireworks.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Fireworks
|
||||||
|
|
||||||
|
This page covers how to use the Fireworks models within Langchain.
|
||||||
|
|
||||||
|
## Installation and Setup
|
||||||
|
|
||||||
|
- To use the Fireworks model, you need to have a Fireworks API key. To generate one, sign up at platform.fireworks.ai
|
||||||
|
- Authenticate by setting the FIREWORKS_API_KEY environment variable.
|
||||||
|
|
||||||
|
## LLM
|
||||||
|
|
||||||
|
Fireworks integrates with Langchain through the LLM module, which allows for standardized usage of any models deployed on the Fireworks models.
|
||||||
|
|
||||||
|
In this example, we'll work the llama-v2-13b.
|
||||||
|
```python
|
||||||
|
from langchain.llms.fireworks import Fireworks
|
||||||
|
|
||||||
|
llm = Fireworks(model="fireworks-llama-v2-13b-chat", max_tokens=256, temperature=0.4)
|
||||||
|
llm("Name 3 sports.")
|
||||||
|
```
|
||||||
|
|
||||||
|
For a more detailed walkthrough, see [here](/docs/extras/modules/model_io/models/llms/integrations/Fireworks.ipynb).
|
@ -39,6 +39,7 @@ from langchain.llms.ctransformers import CTransformers
|
|||||||
from langchain.llms.databricks import Databricks
|
from langchain.llms.databricks import Databricks
|
||||||
from langchain.llms.deepinfra import DeepInfra
|
from langchain.llms.deepinfra import DeepInfra
|
||||||
from langchain.llms.fake import FakeListLLM
|
from langchain.llms.fake import FakeListLLM
|
||||||
|
from langchain.llms.fireworks import Fireworks, FireworksChat
|
||||||
from langchain.llms.forefrontai import ForefrontAI
|
from langchain.llms.forefrontai import ForefrontAI
|
||||||
from langchain.llms.google_palm import GooglePalm
|
from langchain.llms.google_palm import GooglePalm
|
||||||
from langchain.llms.gooseai import GooseAI
|
from langchain.llms.gooseai import GooseAI
|
||||||
@ -98,6 +99,8 @@ __all__ = [
|
|||||||
"Databricks",
|
"Databricks",
|
||||||
"DeepInfra",
|
"DeepInfra",
|
||||||
"FakeListLLM",
|
"FakeListLLM",
|
||||||
|
"Fireworks",
|
||||||
|
"FireworksChat",
|
||||||
"ForefrontAI",
|
"ForefrontAI",
|
||||||
"GPT4All",
|
"GPT4All",
|
||||||
"GooglePalm",
|
"GooglePalm",
|
||||||
|
377
libs/langchain/langchain/llms/fireworks.py
Normal file
377
libs/langchain/langchain/llms/fireworks.py
Normal file
@ -0,0 +1,377 @@
|
|||||||
|
"""Wrapper around Fireworks APIs"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import Field, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFireworks(BaseLLM):
|
||||||
|
"""Wrapper around Fireworks large language models."""
|
||||||
|
|
||||||
|
model_id: str = Field("fireworks-llama-v2-7b-chat", alias="model")
|
||||||
|
"""Model name to use."""
|
||||||
|
temperature: float = 0.7
|
||||||
|
"""What sampling temperature to use."""
|
||||||
|
max_tokens: int = 512
|
||||||
|
"""The maximum number of tokens to generate in the completion.
|
||||||
|
-1 returns as many tokens as possible given the prompt and
|
||||||
|
the models maximal context size."""
|
||||||
|
top_p: float = 1
|
||||||
|
"""Total probability mass of tokens to consider at each step."""
|
||||||
|
fireworks_api_key: Optional[str] = None
|
||||||
|
"""Api key to use fireworks API"""
|
||||||
|
batch_size: int = 20
|
||||||
|
"""Batch size to use when passing multiple documents to generate."""
|
||||||
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||||
|
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
|
||||||
|
max_retries: int = 6
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __new__(cls, **data: Any) -> Any:
|
||||||
|
"""Initialize the Fireworks object."""
|
||||||
|
data.get("model_id", "")
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
values["fireworks_api_key"] = get_from_dict_or_env(
|
||||||
|
values, "fireworks_api_key", "FIREWORKS_API_KEY"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Call out to Fireworks endpoint with k unique prompts.
|
||||||
|
Args:
|
||||||
|
prompts: The prompts to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
Returns:
|
||||||
|
The full LLM output.
|
||||||
|
"""
|
||||||
|
params = {"model": self.model_id}
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
sub_prompts = self.get_batch_prompts(params, prompts, stop)
|
||||||
|
choices = []
|
||||||
|
token_usage: Dict[str, int] = {}
|
||||||
|
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||||
|
for _prompts in sub_prompts:
|
||||||
|
response = completion_with_retry(self, prompt=prompts, **params)
|
||||||
|
choices.extend(response)
|
||||||
|
update_token_usage(_keys, response, token_usage)
|
||||||
|
|
||||||
|
return self.create_llm_result(choices, prompts, token_usage)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Call out to Fireworks endpoint async with k unique prompts."""
|
||||||
|
params = {"model": self.model_id}
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
sub_prompts = self.get_batch_prompts(params, prompts, stop)
|
||||||
|
choices = []
|
||||||
|
token_usage: Dict[str, int] = {}
|
||||||
|
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||||
|
for _prompts in sub_prompts:
|
||||||
|
response = await acompletion_with_retry(self, prompt=_prompts, **params)
|
||||||
|
choices.extend(response)
|
||||||
|
update_token_usage(_keys, response, token_usage)
|
||||||
|
|
||||||
|
return self.create_llm_result(choices, prompts, token_usage)
|
||||||
|
|
||||||
|
def get_batch_prompts(
|
||||||
|
self,
|
||||||
|
params: Dict[str, Any],
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""Get the sub prompts for llm call."""
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
|
||||||
|
sub_prompts = [
|
||||||
|
prompts[i : i + self.batch_size]
|
||||||
|
for i in range(0, len(prompts), self.batch_size)
|
||||||
|
]
|
||||||
|
return sub_prompts
|
||||||
|
|
||||||
|
def create_llm_result(
|
||||||
|
self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Create the LLMResult from the choices and prompts."""
|
||||||
|
generations = []
|
||||||
|
|
||||||
|
for i, _ in enumerate(prompts):
|
||||||
|
sub_choices = choices[i : (i + 1)]
|
||||||
|
generations.append(
|
||||||
|
[
|
||||||
|
Generation(
|
||||||
|
text=choice,
|
||||||
|
)
|
||||||
|
for choice in sub_choices
|
||||||
|
]
|
||||||
|
)
|
||||||
|
llm_output = {"token_usage": token_usage, "model_id": self.model_id}
|
||||||
|
return LLMResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "fireworks"
|
||||||
|
|
||||||
|
|
||||||
|
class FireworksChat(BaseLLM):
|
||||||
|
"""Wrapper around Fireworks Chat large language models.
|
||||||
|
To use, you should have the ``fireworksai`` python package installed, and the
|
||||||
|
environment variable ``FIREWORKS_API_KEY`` set with your API key.
|
||||||
|
Any parameters that are valid to be passed to the fireworks.create
|
||||||
|
call can be passed in, even if not explicitly saved on this class.
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
from langchain.llms import FireworksChat
|
||||||
|
fireworkschat = FireworksChat(model_id=""fireworks-llama-v2-13b-chat"")
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_id: str = "fireworks-llama-v2-7b-chat"
|
||||||
|
"""Model name to use."""
|
||||||
|
temperature: float = 0.7
|
||||||
|
"""What sampling temperature to use."""
|
||||||
|
max_tokens: int = 512
|
||||||
|
"""The maximum number of tokens to generate in the completion.
|
||||||
|
-1 returns as many tokens as possible given the prompt and
|
||||||
|
the models maximal context size."""
|
||||||
|
top_p: float = 1
|
||||||
|
"""Total probability mass of tokens to consider at each step."""
|
||||||
|
fireworks_api_key: Optional[str] = None
|
||||||
|
max_retries: int = 6
|
||||||
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||||
|
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
prefix_messages: List = Field(default_factory=list)
|
||||||
|
"""Series of messages for Chat input."""
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment"""
|
||||||
|
values["fireworks_api_key"] = get_from_dict_or_env(
|
||||||
|
values, "fireworks_api_key", "FIREWORKS_API_KEY"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _get_chat_params(
|
||||||
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
|
) -> Tuple:
|
||||||
|
if len(prompts) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"FireworksChat currently only supports single prompt, got {prompts}"
|
||||||
|
)
|
||||||
|
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
|
||||||
|
params: Dict[str, Any] = {**{"model": self.model_id}}
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
|
||||||
|
return messages, params
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
messages, params = self._get_chat_params(prompts, stop)
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
full_response = completion_with_retry(self, messages=messages, **params)
|
||||||
|
llm_output = {
|
||||||
|
"model_id": self.model_id,
|
||||||
|
}
|
||||||
|
return LLMResult(
|
||||||
|
generations=[[Generation(text=full_response[0])]],
|
||||||
|
llm_output=llm_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
messages, params = self._get_chat_params(prompts, stop)
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
full_response = await acompletion_with_retry(self, messages=messages, **params)
|
||||||
|
llm_output = {
|
||||||
|
"model_id": self.model_id,
|
||||||
|
}
|
||||||
|
return LLMResult(
|
||||||
|
generations=[[Generation(text=full_response[0])]],
|
||||||
|
llm_output=llm_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "fireworks-chat"
|
||||||
|
|
||||||
|
|
||||||
|
class Fireworks(BaseFireworks):
|
||||||
|
"""Wrapper around Fireworks large language models.
|
||||||
|
To use, you should have the ``fireworks`` python package installed, and the
|
||||||
|
environment variable ``FIREWORKS_API_KEY`` set with your API key.
|
||||||
|
Any parameters that are valid to be passed to the fireworks.create
|
||||||
|
call can be passed in, even if not explicitly saved on this class.
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
from langchain.llms import fireworks
|
||||||
|
llm = Fireworks(model_id="fireworks-llama-v2-13b")
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def update_token_usage(
|
||||||
|
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Update token usage."""
|
||||||
|
_keys_to_use = keys.intersection(response)
|
||||||
|
for _key in _keys_to_use:
|
||||||
|
if _key not in token_usage:
|
||||||
|
token_usage[_key] = response["usage"][_key]
|
||||||
|
else:
|
||||||
|
token_usage[_key] += response["usage"][_key]
|
||||||
|
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
max_tokens: int = 256,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute LLM query"""
|
||||||
|
requestUrl = "https://api.fireworks.ai/inference/v1/completions"
|
||||||
|
requestBody = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
}
|
||||||
|
requestHeaders = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
|
||||||
|
def completion_with_retry(
|
||||||
|
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
if "prompt" not in kwargs.keys():
|
||||||
|
answers = []
|
||||||
|
for i in range(len(kwargs["messages"])):
|
||||||
|
result = kwargs["messages"][i]["content"]
|
||||||
|
result = execute(
|
||||||
|
result,
|
||||||
|
kwargs["model"],
|
||||||
|
llm.fireworks_api_key,
|
||||||
|
llm.max_tokens,
|
||||||
|
llm.temperature,
|
||||||
|
llm.top_p,
|
||||||
|
)
|
||||||
|
curr_string = json.loads(result)["choices"][0]["text"]
|
||||||
|
answers.append(curr_string)
|
||||||
|
else:
|
||||||
|
answers = []
|
||||||
|
for i in range(len(kwargs["prompt"])):
|
||||||
|
result = kwargs["prompt"][i]
|
||||||
|
result = execute(
|
||||||
|
result,
|
||||||
|
kwargs["model"],
|
||||||
|
llm.fireworks_api_key,
|
||||||
|
llm.max_tokens,
|
||||||
|
llm.temperature,
|
||||||
|
llm.top_p,
|
||||||
|
)
|
||||||
|
curr_string = json.loads(result)["choices"][0]["text"]
|
||||||
|
answers.append(curr_string)
|
||||||
|
return answers
|
||||||
|
|
||||||
|
|
||||||
|
async def acompletion_with_retry(
|
||||||
|
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the async completion call."""
|
||||||
|
if "prompt" not in kwargs.keys():
|
||||||
|
answers = []
|
||||||
|
for i in range(len(kwargs["messages"])):
|
||||||
|
result = kwargs["messages"][i]["content"]
|
||||||
|
result = execute(
|
||||||
|
result,
|
||||||
|
kwargs["model"],
|
||||||
|
llm.fireworks_api_key,
|
||||||
|
llm.max_tokens,
|
||||||
|
llm.temperature,
|
||||||
|
)
|
||||||
|
curr_string = json.loads(result)["choices"][0]["text"]
|
||||||
|
answers.append(curr_string)
|
||||||
|
else:
|
||||||
|
answers = []
|
||||||
|
for i in range(len(kwargs["prompt"])):
|
||||||
|
result = kwargs["prompt"][i]
|
||||||
|
result = execute(
|
||||||
|
result,
|
||||||
|
kwargs["model"],
|
||||||
|
llm.fireworks_api_key,
|
||||||
|
llm.max_tokens,
|
||||||
|
llm.temperature,
|
||||||
|
)
|
||||||
|
curr_string = json.loads(result)["choices"][0]["text"]
|
||||||
|
answers.append(curr_string)
|
||||||
|
return answers
|
157
libs/langchain/tests/integration_tests/llms/test_fireworks.py
Normal file
157
libs/langchain/tests/integration_tests/llms/test_fireworks.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
"""Test Fireworks AI API Wrapper."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain import LLMChain, PromptTemplate
|
||||||
|
from langchain.chains import RetrievalQA
|
||||||
|
from langchain.document_loaders import TextLoader
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.llms import OpenAIChat
|
||||||
|
from langchain.llms.fireworks import Fireworks, FireworksChat
|
||||||
|
from langchain.llms.loading import load_llm
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
)
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
|
from langchain.vectorstores import DeepLake
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworks_call() -> None:
|
||||||
|
"""Test valid call to fireworks."""
|
||||||
|
llm = Fireworks(model_id="fireworks-llama-v2-13b-chat", max_tokens=900)
|
||||||
|
output = llm("What is the weather in NYC")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworks_in_chain() -> None:
|
||||||
|
"""Tests fireworks AI in a Langchain chain"""
|
||||||
|
human_message_prompt = HumanMessagePromptTemplate(
|
||||||
|
prompt=PromptTemplate(
|
||||||
|
template="What is a good name for a company that makes {product}?",
|
||||||
|
input_variables=["product"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
|
||||||
|
chat = Fireworks()
|
||||||
|
chain = LLMChain(llm=chat, prompt=chat_prompt_template)
|
||||||
|
output = chain.run("football helmets")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_chat_async_generate() -> None:
|
||||||
|
"""Test async chat."""
|
||||||
|
llm = OpenAIChat(max_tokens=10)
|
||||||
|
output = await llm.agenerate(["Hello, how are you?"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworks_model_param() -> None:
|
||||||
|
"""Tests model parameters for Fireworks"""
|
||||||
|
llm = Fireworks(model="foo")
|
||||||
|
assert llm.model_id == "foo"
|
||||||
|
llm = Fireworks(model_id="foo")
|
||||||
|
assert llm.model_id == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworkschat_model_param() -> None:
|
||||||
|
"""Tests model parameters for FireworksChat"""
|
||||||
|
llm = FireworksChat(model="foo")
|
||||||
|
assert llm.model_id == "foo"
|
||||||
|
llm = FireworksChat(model_id="foo")
|
||||||
|
assert llm.model_id == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||||
|
"""Test saving/loading an Fireworks LLM."""
|
||||||
|
llm = Fireworks(max_tokens=10)
|
||||||
|
llm.save(file_path=tmp_path / "fireworks.yaml")
|
||||||
|
loaded_llm = load_llm(tmp_path / "fireworks.yaml")
|
||||||
|
assert loaded_llm == llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworks_multiple_prompts() -> None:
|
||||||
|
"""Test completion with multiple prompts."""
|
||||||
|
llm = Fireworks()
|
||||||
|
output = llm.generate(["How is the weather in New York today?", "I'm pickle rick"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert isinstance(output.generations, list)
|
||||||
|
assert len(output.generations) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworks_chat() -> None:
|
||||||
|
"""Test FireworksChat."""
|
||||||
|
llm = FireworksChat()
|
||||||
|
output = llm("Name me 3 quick facts about the New England Patriots")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fireworks_agenerate() -> None:
|
||||||
|
llm = Fireworks()
|
||||||
|
output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert isinstance(output.generations, list)
|
||||||
|
assert len(output.generations) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fireworkschat_agenerate() -> None:
|
||||||
|
llm = FireworksChat(max_tokens=10)
|
||||||
|
output = await llm.agenerate(["Hello, how are you?"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert isinstance(output.generations, list)
|
||||||
|
assert len(output.generations) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_fireworkschat_chain() -> None:
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
|
||||||
|
loader = TextLoader(
|
||||||
|
"[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
|
||||||
|
)
|
||||||
|
documents = loader.load()
|
||||||
|
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||||
|
docs = text_splitter.split_documents(documents)
|
||||||
|
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
|
||||||
|
db = DeepLake(
|
||||||
|
dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
|
||||||
|
)
|
||||||
|
db.add_documents(docs)
|
||||||
|
|
||||||
|
query = "What did the president say about Ketanji Brown Jackson"
|
||||||
|
docs = db.similarity_search(query)
|
||||||
|
|
||||||
|
qa = RetrievalQA.from_chain_type(
|
||||||
|
llm=FireworksChat(),
|
||||||
|
chain_type="stuff",
|
||||||
|
retriever=db.as_retriever(),
|
||||||
|
)
|
||||||
|
query = "What did the president say about Ketanji Brown Jackson"
|
||||||
|
output = qa.run(query)
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
_EXPECTED_NUM_TOKENS = {
|
||||||
|
"fireworks-llama-v2-13b": 17,
|
||||||
|
"fireworks-llama-v2-7b": 17,
|
||||||
|
"fireworks-llama-v2-13b-chat": 17,
|
||||||
|
"fireworks-llama-v2-7b-chat": 17,
|
||||||
|
}
|
||||||
|
|
||||||
|
_MODELS = models = [
|
||||||
|
"fireworks-llama-v2-13b",
|
||||||
|
"fireworks-llama-v2-7b",
|
||||||
|
"fireworks-llama-v2-13b-chat",
|
||||||
|
"fireworks-llama-v2-7b-chat",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", _MODELS)
|
||||||
|
def test_fireworks_get_num_tokens(model: str) -> None:
|
||||||
|
"""Test get_tokens."""
|
||||||
|
llm = Fireworks(model=model)
|
||||||
|
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
|
Loading…
Reference in New Issue
Block a user