Refactor Fireworks and add ChatFireworks (#3) (#10597)

Description 
* Refactor Fireworks within Langchain LLMs.
* Remove FireworksChat within Langchain LLMs.
* Add ChatFireworks (which uses chat completion api) to Langchain chat
models.
* Users have to install `fireworks-ai` and register an api key to use
the api.

Issue - Not applicable
Dependencies - None
Tag maintainer - @rlancemartin @baskaryan
pull/11117/head
Cynthia Yang 10 months ago committed by GitHub
parent 5514ebe859
commit 6dd44ff1c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,255 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "642fd21c-600a-47a1-be96-6e1438b421a9",
"metadata": {},
"source": [
"# ChatFireworks\n",
"\n",
">[Fireworks](https://app.fireworks.ai/) accelerates product development on generative AI by creating an innovative AI experiment and production platform. \n",
"\n",
"This example goes over how to use LangChain to interact with `ChatFireworks` models."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d00d850917865298",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"from langchain.chat_models.fireworks import ChatFireworks\n",
"from langchain.schema import SystemMessage, HumanMessage\n",
"import os"
]
},
{
"cell_type": "markdown",
"id": "f28ebf8b-f14f-46c7-9962-8b8dc42e31be",
"metadata": {},
"source": [
"# Setup\n",
"Contact Fireworks AI for the an API Key to access our models\n",
"\n",
"Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-7b-chat."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d096fb14-8acc-4047-9cd0-c842430c3a1d",
"metadata": {},
"outputs": [],
"source": [
"# Initialize a Fireworks Chat model\n",
"os.environ['FIREWORKS_API_KEY'] = \"<your_api_key>\" # Change this to your own API key\n",
"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
]
},
{
"cell_type": "markdown",
"id": "d8f13144-37cf-47a5-b5a0-e3cdf76d9a72",
"metadata": {},
"source": [
"# Calling the Model\n",
"\n",
"You can use the LLMs to call the model for specified message(s). \n",
"\n",
"See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "72340871-ae2f-415f-b399-0777d32dc379",
"metadata": {},
"outputs": [],
"source": [
"# ChatFireworks Wrapper\n",
"system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
"human_message = HumanMessage(content=\"Who are you?\")\n",
"response = chat([system_message, human_message])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d6ef879-69e3-422b-8379-bb980b70fe55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist users with tasks and answer questions to the best of my ability. I am capable of understanding and responding to natural language input, and I am here to help you with any questions or tasks you may have. Is there anything specific you would like to know or discuss?\", additional_kwargs={}, example=False)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8",
"metadata": {},
"outputs": [],
"source": [
"# Setting additional parameters: temperature, max_tokens, top_p\n",
"chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":1, \"max_tokens\": 20, \"top_p\": 1})\n",
"system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
"human_message = HumanMessage(content=\"How's the weather today?\")\n",
"response = chat([system_message, human_message])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a09025f8-e4c3-4005-a8fc-c9c774b03a64",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Oh, you know, it's just another beautiful day in the virtual world! The sun\", additional_kwargs={}, example=False)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "markdown",
"id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8",
"metadata": {},
"source": [
"# ChatFireworks Wrapper with generate"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589",
"metadata": {},
"outputs": [],
"source": [
"chat = ChatFireworks()\n",
"message = HumanMessage(content=\"Hello\")\n",
"response = chat.generate([[message], [message]])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "35109f36-9519-47a6-a223-25639123e836",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", additional_kwargs={}, example=False))], [ChatGeneration(text=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", additional_kwargs={}, example=False))]], llm_output={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, run=[RunInfo(run_id=UUID('f137463e-e1c7-454a-8b85-b999ce20e0f2')), RunInfo(run_id=UUID('f3ef1138-92de-4e01-900b-991e34a647a7'))])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "markdown",
"id": "92c2cabb-9eaf-4c49-b0e5-a5de5a7d920e",
"metadata": {},
"source": [
"# ChatFireworks Wrapper with stream"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "12717a29-fb7d-4a4d-860b-40435452b065",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Hello! I'm just\n",
" an AI assistant,\n",
" here to help answer your\n",
" questions and provide information in\n",
" a responsible and respectful manner\n",
". I'm not able\n",
" to access personal information or provide\n",
" any content that could be considered\n",
" harmful, uneth\n",
"ical, racist, sex\n",
"ist, toxic, dangerous\n",
", or illegal. My purpose\n",
" is to assist and provide helpful\n",
" responses that are socially un\n",
"biased and positive in nature\n",
". Is there something specific you\n",
" would like to know or discuss\n",
"?\n"
]
}
],
"source": [
"llm = ChatFireworks()\n",
"\n",
"for token in llm.stream(\"Who are you\"):\n",
" print(token.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02991e05-a38e-47d4-9ab3-7e630a8ead55",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -19,7 +19,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.fireworks import Fireworks, FireworksChat\n",
"from langchain.llms.fireworks import Fireworks\n",
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
@ -48,8 +48,8 @@
"outputs": [],
"source": [
"# Initialize a Fireworks LLM\n",
"os.environ['FIREWORKS_API_KEY'] = \"<YOUR_API_KEY>\" # Change this to your own API key\n",
"llm = Fireworks(model_id=\"accounts/fireworks/models/llama-v2-13b-chat\")"
"os.environ['FIREWORKS_API_KEY'] = \"<your_api_key>\" # Change this to your own API key\n",
"llm = Fireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
]
},
{
@ -61,28 +61,7 @@
"\n",
"You can use the LLMs to call the model for specified prompt(s). \n",
"\n",
"Currently supported models: \n",
"\n",
"* Falcon\n",
" * `accounts/fireworks/models/falcon-7b`\n",
" * `accounts/fireworks/models/falcon-40b-w8a16`\n",
"* Llama 2\n",
" * `accounts/fireworks/models/llama-v2-7b`\n",
" * `accounts/fireworks/models/llama-v2-7b-w8a16`\n",
" * `accounts/fireworks/models/llama-v2-7b-chat`\n",
" * `accounts/fireworks/models/llama-v2-7b-chat-w8a16`\n",
" * `accounts/fireworks/models/llama-v2-13b`\n",
" * `accounts/fireworks/models/llama-v2-13b-w8a16`\n",
" * `accounts/fireworks/models/llama-v2-13b-chat`\n",
" * `accounts/fireworks/models/llama-v2-13b-chat-w8a16`\n",
" * `accounts/fireworks/models/llama-v2-70b-chat-4gpu`\n",
"* StarCoder\n",
" * `accounts/fireworks/models/starcoder-1b-w8a16-1gpu`\n",
" * `accounts/fireworks/models/starcoder-3b-w8a16-1gpu`\n",
" * `accounts/fireworks/models/starcoder-7b-w8a16-1gpu`\n",
" * `accounts/fireworks/models/starcoder-16b-w8a16`\n",
"\n",
"See the full, most up-to-date list on [app.fireworks.ai](https://app.fireworks.ai)."
"See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
]
},
{
@ -95,29 +74,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Is it Tom Brady, Aaron Rodgers, or someone else? It's a tough question to answer, and there are strong arguments for each of these quarterbacks. Here are some of the reasons why each of these quarterbacks could be considered the best:\n",
"\n",
"Tom Brady:\n",
"\n",
"* He has the most Super Bowl wins (6) of any quarterback in NFL history.\n",
"* He has been named Super Bowl MVP four times, more than any other player.\n",
"* He has led the New England Patriots to 18 playoff victories, the most in NFL history.\n",
"* He has thrown for over 70,000 yards in his career, the most of any quarterback in NFL history.\n",
"* He has thrown for 50 or more touchdowns in a season four times, the most of any quarterback in NFL history.\n",
"\n",
"Aaron Rodgers:\n",
"\n",
"* He has led the Green Bay Packers to a Super Bowl victory in 2010.\n",
"* He has been named Super Bowl MVP once.\n",
"* He has thrown for over 40,000 yards in his career, the most of any quarterback in NFL history.\n",
"* He has thrown for 40 or more touchdowns in a season three times, the most of any quarterback in NFL history.\n",
"* He has a career passer rating of 103.1, the highest of any quarterback in NFL history.\n",
"It's a question that's been debated for years, and there are plenty of strong candidates. Here are some of the top quarterbacks in the league right now:\n",
"\n",
"So, who's the best quarterback in the NFL? It's a tough call, but here's my opinion:\n",
"1. Tom Brady (New England Patriots): Brady is widely considered one of the greatest quarterbacks of all time, and for good reason. He's led the Patriots to six Super Bowl wins and has been named Super Bowl MVP four times. He's known for his precision passing and ability to read defenses.\n",
"2. Aaron Rodgers (Green Bay Packers): Rodgers is another top-tier quarterback who's known for his accuracy and ability to make plays outside of the pocket. He's led the Packers to a Super Bowl win and has been named NFL MVP twice.\n",
"3. Drew Brees (New Orleans Saints): Brees is one of the most prolific passers in NFL history, and he's shown no signs of slowing down. He's led the Saints to a Super Bowl win and has been named NFL MVP once.\n",
"4. Russell Wilson (Seattle Seahawks): Wilson is a dynamic quarterback who's known for his ability to make plays with his legs and his arm. He's led the Seahawks to a Super Bowl win and has been named NFL MVP once.\n",
"5. Patrick Mahomes (Kansas City Chiefs): Mahomes is a young quarterback who's quickly become one of the best in the league. He led the Chiefs to a Super Bowl win last season and has been named NFL MVP twice. He's known for his incredible arm talent and ability to make plays outside of the pocket.\n",
"\n",
"I think Aaron Rodgers is the best quarterback in the NFL right now. He has led the Packers to a Super Bowl victory and has had some incredible seasons, including the 2011 season when he threw for 45 touchdowns and just 6 interceptions. He has a strong arm, great accuracy, and is incredibly mobile for a quarterback of his size. He also has a great sense of timing and knows when to take risks and when to play it safe.\n",
"\n",
"Tom Brady is a close second, though. He has an incredible track record of success, including six Super Bowl victories, and has been one of the most consistent quarterbacks in the league for the past two decades. He has a strong arm and is incredibly accurate\n"
"Of course, there are other great quarterbacks in the league as well, such as Ben Roethlisberger, Matt Ryan, and Deshaun Watson. Ultimately, the \"best\" quarterback is a matter of personal opinion and depends on how you define \"best.\" Some people might value accuracy and precision passing, while others might prefer a quarterback who can make plays with their legs. Either way, the NFL is filled with talented quarterbacks who are making incredible plays every week.\n"
]
}
],
@ -137,7 +104,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[[Generation(text='\\nThe best cricket player in 2016 is a matter of opinion, but some of the top contenders for the title include:\\n\\n1. Virat Kohli (India): Kohli had a phenomenal year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 70. He also scored heavily in ODI cricket, with an average of over 80.\\n2. Steve Smith (Australia): Smith had a remarkable year in 2016, leading Australia to a Test series victory in India and scoring over 1,000 runs in the format, including five centuries. He also averaged over 60 in ODI cricket.\\n3. KL Rahul (India): Rahul had a breakout year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 60. He also scored heavily in ODI cricket, with an average of over 70.\\n4. Joe Root (England): Root had a solid year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 50. He also scored heavily in ODI cricket, with an average of over 80.\\n5. Quinton de Kock (South Africa): De Kock had a remarkable year in 2016, scoring over 1,000 runs in ODI cricket, including six centuries, and averaging over 80. He also scored heavily in Test cricket, with an average of over 50.\\n\\nThese are just a few of the top contenders for the title of best cricket player in 2016, but there were many other talented players who also had impressive years. Ultimately, the answer to this question is subjective and depends on individual opinions and criteria for evaluation.', generation_info=None)], [Generation(text=\"\\nThis is a tough one, as there are so many great players in the league right now. But if I had to choose one, I'd say LeBron James is the best basketball player in the league. He's a once-in-a-generation talent who can dominate the game in so many ways. He's got incredible speed, strength, and court vision, and he's always finding new ways to improve his game. Plus, he's been doing it at an elite level for over a decade now, which is just amazing.\\n\\nBut don't just take my word for it - there are plenty of other great players in the league who could make a strong case for being the best. Guys like Kevin Durant, Steph Curry, James Harden, and Giannis Antetokounmpo are all having incredible seasons, and they've all got their own unique skills and strengths that make them special. So ultimately, it's up to you to decide who you think is the best basketball player in the league.\", generation_info=None)]]\n"
"[[Generation(text=\"\\n\\nNote: This is a subjective question, and the answer will depend on individual opinions and perspectives.\\n\\nThere are many great cricket players, and it's difficult to identify a single best player. However, here are some of the top performers in 2016:\\n\\n1. Virat Kohli (India): Kohli had an outstanding year in all formats of the game, scoring heavily in Tests, ODIs, and T20Is. He was especially impressive in the Test series against England, where he scored four centuries and averaged over 100.\\n2. Steve Smith (Australia): Smith had a phenomenal year as well, leading Australia to a Test series win in India and averaging over 100 in the longer format. He also scored a century in the ODI series against Pakistan.\\n3. Kane Williamson (New Zealand): Williamson had a consistent year, scoring heavily in all formats and leading New Zealand to a Test series win against Australia. He also won the ICC Test Player of the Year award.\\n4. Joe Root (England): Root had a solid year, scoring three hundreds in the Test series against Pakistan and India, and averaging over 50 in Tests.\\n5. AB de Villiers (South Africa): De Villiers had a brilliant year in ODIs, scoring four hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 50.\\n6. Quinton de Kock (South Africa): De Kock had a great year behind the wickets, scoring heavily in all formats and averaging over 50 in Tests.\\n7. Rohit Sharma (India): Sharma had a fantastic year in ODIs, scoring four hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 40.\\n8. David Warner (Australia): Warner had a great year in ODIs, scoring three hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 40.\\n\\nThese are just a few examples of top performers in 2016, and opinions on the best player will vary depending on individual perspectives\", generation_info=None)], [Generation(text='\\n\\nThere are a lot of great players in the NBA, and opinions on who\\'s the best can vary depending on personal preferences and criteria for evaluation. However, here are some of the top candidates for the title of best basketball player in the league based on their recent performances and achievements:\\n\\n1. LeBron James: James is a four-time NBA champion and four-time MVP, and is widely regarded as one of the greatest players of all time. He has led the Los Angeles Lakers to the best record in the Western Conference this season and is averaging 25.7 points, 7.9 rebounds, and 7.4 assists per game.\\n2. Giannis Antetokounmpo: Antetokounmpo, known as the \"Greek Freak,\" is a dominant force in the paint and has led the Milwaukee Bucks to the best record in the Eastern Conference. He is averaging 30.5 points, 12.6 rebounds, and 5.9 assists per game, and is a strong contender for the MVP award.\\n3. Stephen Curry: Curry is a three-time NBA champion and two-time MVP, and is known for his incredible shooting ability. He has led the Golden State Warriors to the playoffs despite injuries to key players, and is averaging 23.5 points, 5.2 rebounds, and 5.2 assists per game.\\n4. Kevin Durant: Durant is a two-time NBA champion and four-time scoring champion, and is one of the most skilled scorers in the league. He has led the Brooklyn Nets to the playoffs in their first season since moving from New Jersey, and is averaging 27.2 points, 7.2 rebounds, and 6.4 assists per game.\\n5. James Harden: Harden is a three-time scoring champion and has led the Houston Rockets to the playoffs for the past eight seasons. He is averaging 35.4 points, 8.3 rebounds, and 7.5 assists per game, and is a strong contender for the MVP award.\\n\\nUltimately, determining the best basketball player in the league is subjective and depends on individual opinions and criteria. However, these five players are among', generation_info=None)]]\n"
]
}
],
@ -161,13 +128,13 @@
"output_type": "stream",
"text": [
"\n",
"Kansas City in December is quite cold, with temperatures typically r\n"
"Kansas City's weather in December can be quite chilly,\n"
]
}
],
"source": [
"# Setting additional parameters: temperature, max_tokens, top_p\n",
"llm = Fireworks(model_id=\"accounts/fireworks/models/llama-v2-13b-chat\", temperature=0.7, max_tokens=15, top_p=1.0)\n",
"llm = Fireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":0.7, \"max_tokens\":15, \"top_p\":1.0})\n",
"print(llm(\"What's the weather like in Kansas City in December?\"))"
]
},
@ -192,30 +159,140 @@
"output_type": "stream",
"text": [
"\n",
"Naming a company can be a fun and creative process! Here are a few name ideas for a company that makes football helmets:\n",
"\n",
"1. Helix Headgear: This name plays off the idea of the helix shape of a football helmet and could be a memorable and catchy name for a company.\n",
"2. Gridiron Gear: \"Gridiron\" is a term used to describe a football field, and \"gear\" refers to the products the company sells. This name is straightforward and easy to understand.\n",
"3. Cushion Crusaders: This name emphasizes the protective qualities of football helmets and could appeal to customers looking for safety-conscious products.\n",
"4. Helmet Heroes: This name has a fun, heroic tone and could appeal to customers looking for high-quality products.\n",
"5. Tackle Tech: \"Tackle\" is a term used in football to describe a player's attempt to stop an opponent, and \"tech\" refers to the technology used in the helmets. This name could appeal to customers interested in innovative products.\n",
"6. Padded Protection: This name emphasizes the protective qualities of football helmets and could appeal to customers looking for products that prioritize safety.\n",
"7. Gridiron Gear Co.: This name is simple and straightforward, and it clearly conveys the company's focus on football-related products.\n",
"8. Helmet Haven: This name has a soothing, protective tone and could appeal to customers looking for a reliable brand.\n",
"Assistant: That's a great question! There are many factors to consider when choosing a name for a company that makes football helmets. Here are a few suggestions:\n",
"\n",
"1. Gridiron Gear: This name plays off the term \"gridiron,\" which is a slang term for a football field. It also suggests that the company's products are high-quality and durable, like gear used in a gridiron game.\n",
"2. Helmet Headquarters: This name is straightforward and to the point. It clearly communicates that the company is a leading manufacturer of football helmets.\n",
"3. Tackle Tough: This name plays off the idea of tackling a tough opponent on the football field. It suggests that the company's helmets are designed to protect players from even the toughest hits.\n",
"4. Block Breakthrough: This name is a play on words that suggests the company's helmets are breaking through the competition. It also implies that the company is innovative and forward-thinking.\n",
"5. First Down Fashion: This name combines the idea of scoring a first down on the football field with the idea of fashionable clothing. It suggests that the company's helmets are not only functional but also stylish.\n",
"\n",
"Remember to choose a name that reflects your company's values and mission, and that resonates with your target market. Good luck with your company!\n"
"I hope these suggestions help you come up with a great name for your company!\n"
]
}
],
"source": [
"human_message_prompt = HumanMessagePromptTemplate.from_template(\"What is a good name for a company that makes {product}?\")\n",
"chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])\n",
"chat = FireworksChat()\n",
"chat = Fireworks()\n",
"chain = LLMChain(llm=chat, prompt=chat_prompt_template)\n",
"output = chain.run(\"football helmets\")\n",
"\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "25812db3-23a6-41dd-8636-5a49c52bb6eb",
"metadata": {},
"source": [
"# Run Stream"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "26d67ecf-9290-4ec2-8b39-ff17fc99620f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Tom Brady, Aaron Rod\n",
"gers, or Drew Bre\n",
"es?\n",
"Some people might\n",
" say Tom Brady, who\n",
" has won six Super Bowls\n",
" and four Super Bowl MVP\n",
" awards, is the best quarter\n",
"back in the NFL. O\n",
"thers might argue that Aaron\n",
" Rodgers, who has led\n",
" his team to a Super Bowl\n",
" victory and has been named the\n",
" NFL MVP twice, is\n",
" the best. Still, others\n",
" might say that Drew Bre\n",
"es, who holds the NFL\n",
" record for most career passing yards\n",
" and has led his team to\n",
" a Super Bowl victory, is\n",
" the best.\n",
"But what\n",
" if I told you there'\n",
"s actually a fourth quarterback\n",
" who could make a strong case\n",
" for being the best in the\n",
" NFL? Meet Russell Wilson\n",
", the Seattle Seahaw\n",
"ks' dynamic signal-call\n",
"er who has led his team\n",
" to a Super Bowl victory and\n",
" has been named the NFL M\n",
"VP twice.\n",
"Wilson\n",
" has a unique combination of physical\n",
" and mental skills that set him\n",
" apart from other quarterbacks\n",
" in the league. He'\n",
"s incredibly athletic,\n",
" with the ability to make plays\n",
" with his feet and his arm\n",
", and he's also\n",
" highly intelligent, with a\n",
" quick mind and the ability to\n",
" read defenses like a pro\n",
".\n",
"But what really\n",
" sets Wilson apart is his\n",
" leadership ability. He'\n",
"s a natural-born\n",
" leader who has a way\n",
" of inspiring his team\n",
"mates and getting them\n",
" to buy into his vision\n",
" for the game. He\n",
"'s also an excellent\n",
" communicator, who can\n",
" articulate his strategy\n",
" and game plan in a\n",
" way that his teamm\n",
"ates can understand and execute\n",
".\n",
"So, who\n",
"'s the best quarter\n",
"back in the NFL?\n",
" It's hard to\n",
" say for sure, but\n",
" if you ask me,\n",
" Russell Wilson is definitely in\n",
" the conversation. He'\n",
"s got the physical skills\n",
", the mental skills,\n",
" and the leadership ability to\n",
" be the best of the\n",
" best.\n"
]
}
],
"source": [
"llm = Fireworks()\n",
"generator = llm.stream(\"Who's the best quarterback in the NFL?\")\n",
"\n",
"for token in generator:\n",
" print(token)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a35e0b-c875-493a-8143-d802d273247c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

@ -0,0 +1,264 @@
import fireworks
import fireworks.client
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import create_base_retry_decorator
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
"""Convert a delta response to a message chunk."""
role = _dict.role
content = _dict.content or ""
additional_kwargs = {}
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict.name)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dict response to a message."""
role = _dict.role
content = _dict.content or ""
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
content = _dict.content
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=content)
elif role == "function":
return FunctionMessage(content=content, name=_dict.name)
else:
return ChatMessage(content=content, role=role)
class ChatFireworks(BaseChatModel):
"""Fireworks Chat models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
fireworks_api_key: Optional[str] = None
max_retries: int = 20
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
fireworks.client.api_key = fireworks_api_key
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages, stop)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = completion_with_retry(self, **params)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages, stop)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = await acompletion_with_retry(self, **params)
return self._create_chat_result(response)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return llm_outputs[0]
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response.choices:
message = convert_dict_to_message(res.message)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.finish_reason),
)
generations.append(gen)
llm_output = {"model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]]]:
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages, stop)
default_chunk_class = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
}
for chunk in completion_with_retry(self, **params):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages, stop)
default_chunk_class = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
}
async for chunk in await acompletion_with_retry_streaming(self, **params):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
def completion_with_retry(
llm: ChatFireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: ChatFireworks,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
llm: ChatFireworks,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call for streaming."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
def _create_retry_decorator(
llm: ChatFireworks,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Define retry mechanism."""
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)

@ -44,7 +44,7 @@ from langchain.llms.deepinfra import DeepInfra
from langchain.llms.deepsparse import DeepSparse
from langchain.llms.edenai import EdenAI
from langchain.llms.fake import FakeListLLM
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.fireworks import Fireworks
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI

@ -1,377 +1,220 @@
"""Wrapper around Fireworks APIs"""
import json
import logging
from typing import (
Any,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
import requests
import fireworks
import fireworks.client
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class BaseFireworks(BaseLLM):
"""Wrapper around Fireworks large language models."""
model_id: str = Field("accounts/fireworks/models/llama-v2-7b-chat", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
return GenerationChunk(
text=stream_response.choices[0].text,
generation_info=dict(
finish_reason=stream_response.choices[0].finish_reason,
logprobs=stream_response.choices[0].logprobs,
),
)
class Fireworks(LLM):
"""Fireworks models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
fireworks_api_key: Optional[str] = None
"""Api key to use fireworks API"""
batch_size: int = 20
"""Batch size to use when passing multiple documents to generate."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def __new__(cls, **data: Any) -> Any:
"""Initialize the Fireworks object."""
data.get("model_id", "")
return super().__new__(cls)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
max_retries: int = 20
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["fireworks_api_key"] = get_from_dict_or_env(
"""Validate that api key in environment."""
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
fireworks.client.api_key = fireworks_api_key
return values
def _generate(
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks"
def _call(
self,
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint with k unique prompts.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The full LLM output.
"""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = completion_with_retry(self, prompt=prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
) -> str:
"""Run the LLM on the given prompt and input."""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
response = completion_with_retry(self, **params)
return self.create_llm_result(choices, prompts, token_usage)
return response.choices[0].text
async def _agenerate(
async def _acall(
self,
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Fireworks endpoint async with k unique prompts."""
params = {"model": self.model_id}
params = {**params, **kwargs}
sub_prompts = self.get_batch_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = await acompletion_with_retry(self, prompt=_prompts, **params)
choices.extend(response)
update_token_usage(_keys, response, token_usage)
) -> str:
"""Run the LLM on the given prompt and input."""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
response = await acompletion_with_retry(self, **params)
return self.create_llm_result(choices, prompts, token_usage)
return response.choices[0].text
def get_batch_prompts(
def _stream(
self,
params: Dict[str, Any],
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
) -> List[List[str]]:
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
sub_prompts = [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
def create_llm_result(
self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
generations = []
for i, _ in enumerate(prompts):
sub_choices = choices[i : (i + 1)]
generations.append(
[
Generation(
text=choice,
)
for choice in sub_choices
]
)
llm_output = {"token_usage": token_usage, "model_id": self.model_id}
return LLMResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks"
class FireworksChat(BaseLLM):
"""Wrapper around Fireworks Chat large language models.
To use, you should have the ``fireworksai`` python package installed, and the
environment variable ``FIREWORKS_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import FireworksChat
fireworkschat = FireworksChat(model_id=""llama-v2-13b-chat"")
"""
model_id: str = "accounts/fireworks/models/llama-v2-7b-chat"
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion.
-1 returns as many tokens as possible given the prompt and
the models maximal context size."""
top_p: float = 1
"""Total probability mass of tokens to consider at each step."""
fireworks_api_key: Optional[str] = None
max_retries: int = 6
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Fireworks completion API. Default is 600 seconds."""
"""Maximum number of retries to make when generating."""
prefix_messages: List = Field(default_factory=list)
"""Series of messages for Chat input."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment"""
values["fireworks_api_key"] = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
return values
def _get_chat_params(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> Tuple:
if len(prompts) > 1:
raise ValueError(
f"FireworksChat currently only supports single prompt, got {prompts}"
)
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
params: Dict[str, Any] = {**{"model": self.model_id}}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
return messages, params
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {
"model": self.model,
"prompt": prompt,
"stream": True,
**self.model_kwargs,
}
for stream_resp in completion_with_retry(self, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
def _generate(
async def _astream(
self,
prompts: List[str],
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
) -> Iterator[GenerationChunk]:
params = {
"model": self.model,
"prompt": prompt,
"stream": True,
**self.model_kwargs,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
async for stream_resp in await acompletion_with_retry_streaming(self, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
async def _agenerate(
def stream(
self,
prompts: List[str],
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(self, messages=messages, **params)
llm_output = {
"model_id": self.model_id,
}
return LLMResult(
generations=[[Generation(text=full_response[0])]],
llm_output=llm_output,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
) -> Iterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
class Fireworks(BaseFireworks):
"""Wrapper around Fireworks large language models.
To use, you should have the ``fireworks`` python package installed, and the
environment variable ``FIREWORKS_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the fireworks.create
call can be passed in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.llms import fireworks
llm = Fireworks(model_id="llama-v2-13b")
"""
def completion_with_retry(
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.create(
**kwargs,
)
def update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response)
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
return _completion_with_retry(**kwargs)
def execute(
prompt: str,
model: str,
api_key: Optional[str],
max_tokens: int = 256,
temperature: float = 0.0,
top_p: float = 1.0,
async def acompletion_with_retry(
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Execute LLM query"""
requestUrl = "https://api.fireworks.ai/inference/v1/completions"
requestBody = {
"model": model,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
requestHeaders = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
response = requests.post(requestUrl, headers=requestHeaders, json=requestBody)
return response.text
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.Completion.acreate(
**kwargs,
)
def completion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
llm.top_p,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: Union[BaseFireworks, FireworksChat], **kwargs: Any
async def acompletion_with_retry_streaming(
llm: Fireworks,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
if "prompt" not in kwargs.keys():
answers = []
for i in range(len(kwargs["messages"])):
result = kwargs["messages"][i]["content"]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
else:
answers = []
for i in range(len(kwargs["prompt"])):
result = kwargs["prompt"][i]
result = execute(
result,
kwargs["model"],
llm.fireworks_api_key,
llm.max_tokens,
llm.temperature,
)
curr_string = json.loads(result)["choices"][0]["text"]
answers.append(curr_string)
return answers
"""Use tenacity to retry the completion call for streaming."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
def _create_retry_decorator(
llm: Fireworks,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Define retry mechanism."""
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -2390,6 +2390,21 @@ calc = ["shapely"]
s3 = ["boto3 (>=1.3.1)"]
test = ["Fiona[s3]", "pytest (>=7)", "pytest-cov", "pytz"]
[[package]]
name = "fireworks-ai"
version = "0.4.1"
description = "Python client library for the Fireworks.ai Generative AI Platform"
optional = true
python-versions = ">=3.9"
files = [
{file = "fireworks_ai-0.4.1-py3-none-any.whl", hash = "sha256:6ac124ffcd783442b4569e4127adafb0bde6861b6ce5d7a7d162d3920e7cc4e9"},
]
[package.dependencies]
httpx = "*"
httpx-sse = "*"
pydantic = "*"
[[package]]
name = "flatbuffers"
version = "23.5.26"
@ -3171,6 +3186,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "httpx-sse"
version = "0.3.1"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = true
python-versions = ">=3.7"
files = [
{file = "httpx-sse-0.3.1.tar.gz", hash = "sha256:3bb3289b2867f50cbdb2fee3eeeefecb1e86653122e164faac0023f1ffc88aea"},
{file = "httpx_sse-0.3.1-py3-none-any.whl", hash = "sha256:7376dd88732892f9b6b549ac0ad05a8e2341172fe7dcf9f8f9c8050934297316"},
]
[[package]]
name = "huggingface-hub"
version = "0.16.4"
@ -5818,7 +5844,7 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
@ -8820,7 +8846,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
typing-extensions = ">=4.2.0"
[package.extras]
@ -10622,4 +10648,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"
content-hash = "2d24ce7353641663405c132acfbd45492f56c4e53372eac4b698a6ace1eb27b7"

@ -132,6 +132,7 @@ sqlite-vss = {version = "^0.1.2", optional = true}
anyio = "<4.0"
jsonpatch = "^1.33"
timescale-vector = {version = "^0.0.1", optional = true}
fireworks-ai = {version = "^0.4.1", optional = true, python = ">=3.9"}
[tool.poetry.group.test.dependencies]

@ -0,0 +1,106 @@
"""Test ChatFireworks wrapper."""
import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
def test_chat_fireworks() -> None:
"""Test ChatFireworks wrapper."""
chat = ChatFireworks()
message = HumanMessage(content="What is the weather in Redwood City, CA today")
response = chat([message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_fireworks_model() -> None:
"""Test ChatFireworks wrapper handles model_name."""
chat = ChatFireworks(model="foo")
assert chat.model == "foo"
def test_chat_fireworks_system_message() -> None:
"""Test ChatFireworks wrapper with system message."""
chat = ChatFireworks()
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat([system_message, human_message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_fireworks_generate() -> None:
"""Test ChatFireworks wrapper with generate."""
chat = ChatFireworks(model_kwargs={"n": 2})
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test_chat_fireworks_multiple_completions() -> None:
"""Test ChatFireworks wrapper with multiple completions."""
chat = ChatFireworks(model_kwargs={"n": 5})
message = HumanMessage(content="Hello")
response = chat._generate([message])
assert isinstance(response, ChatResult)
assert len(response.generations) == 5
for generation in response.generations:
assert isinstance(generation.message, BaseMessage)
assert isinstance(generation.message.content, str)
def test_chat_fireworks_llm_output_contains_model_id() -> None:
"""Test llm_output contains model_id."""
chat = ChatFireworks()
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model"] == chat.model
def test_fireworks_streaming() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatFireworks()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
"""Test ChatFireworks wrapper with generate."""
chat = ChatFireworks(model_kwargs={"n": 2})
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 2
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_fireworks_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatFireworks()
async for token in llm.astream("Who's the best quarterback in the NFL?"):
assert isinstance(token.content, str)

@ -1,31 +1,20 @@
"""Test Fireworks AI API Wrapper."""
from pathlib import Path
import pytest
from langchain.chains import RetrievalQA
from langchain.chains.llm import LLMChain
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAIChat
from langchain.llms.fireworks import Fireworks, FireworksChat
from langchain.llms.loading import load_llm
from typing import Generator
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms.fireworks import Fireworks
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import DeepLake
import pytest
def test_fireworks_call() -> None:
"""Test valid call to fireworks."""
llm = Fireworks(
model_id="accounts/fireworks/models/fireworks-llama-v2-13b-chat", max_tokens=900
)
output = llm("What is the weather in NYC")
llm = Fireworks()
output = llm("Who's the best quarterback in the NFL?")
assert isinstance(output, str)
@ -44,36 +33,10 @@ def test_fireworks_in_chain() -> None:
assert isinstance(output, str)
@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
"""Test async chat."""
llm = OpenAIChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
def test_fireworks_model_param() -> None:
"""Tests model parameters for Fireworks"""
llm = Fireworks(model="foo")
assert llm.model_id == "foo"
llm = Fireworks(model_id="foo")
assert llm.model_id == "foo"
def test_fireworkschat_model_param() -> None:
"""Tests model parameters for FireworksChat"""
llm = FireworksChat(model="foo")
assert llm.model_id == "foo"
llm = FireworksChat(model_id="foo")
assert llm.model_id == "foo"
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an Fireworks LLM."""
llm = Fireworks(max_tokens=10)
llm.save(file_path=tmp_path / "fireworks.yaml")
loaded_llm = load_llm(tmp_path / "fireworks.yaml")
assert loaded_llm == llm
assert llm.model == "foo"
def test_fireworks_multiple_prompts() -> None:
@ -85,76 +48,39 @@ def test_fireworks_multiple_prompts() -> None:
assert len(output.generations) == 2
def test_fireworks_chat() -> None:
"""Test FireworksChat."""
llm = FireworksChat()
output = llm("Name me 3 quick facts about the New England Patriots")
assert isinstance(output, str)
async def test_fireworks_agenerate() -> None:
def test_fireworks_streaming() -> None:
"""Test stream completion."""
llm = Fireworks()
output = await llm.agenerate(["I'm a pickle", "I'm a pickle"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2
generator = llm.stream("Who's the best quarterback in the NFL?")
assert isinstance(generator, Generator)
async def test_fireworkschat_agenerate() -> None:
llm = FireworksChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 1
for token in generator:
assert isinstance(token, str)
def test_fireworkschat_chain() -> None:
embeddings = OpenAIEmbeddings()
@pytest.mark.asyncio
async def test_fireworks_streaming_async() -> None:
"""Test stream completion."""
llm = Fireworks()
loader = TextLoader(
"[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt"
)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
async for token in llm.astream("Who's the best quarterback in the NFL?"):
assert isinstance(token, str)
embeddings = OpenAIEmbeddings()
db = DeepLake(
dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True
)
db.add_documents(docs)
@pytest.mark.asyncio
async def test_fireworks_async_agenerate() -> None:
"""Test async."""
llm = Fireworks()
output = await llm.agenerate(["What is the best city to live in California?"])
assert isinstance(output, LLMResult)
query = "What did the president say about Ketanji Brown Jackson"
docs = db.similarity_search(query)
qa = RetrievalQA.from_chain_type(
llm=FireworksChat(),
chain_type="stuff",
retriever=db.as_retriever(),
@pytest.mark.asyncio
async def test_fireworks_multiple_prompts_async_agenerate() -> None:
llm = Fireworks()
output = await llm.agenerate(
["How is the weather in New York today?", "I'm pickle rick"]
)
query = "What did the president say about Ketanji Brown Jackson"
output = qa.run(query)
assert isinstance(output, str)
_EXPECTED_NUM_TOKENS = {
"accounts/fireworks/models/fireworks-llama-v2-13b": 17,
"accounts/fireworks/models/fireworks-llama-v2-7b": 17,
"accounts/fireworks/models/fireworks-llama-v2-13b-chat": 17,
"accounts/fireworks/models/fireworks-llama-v2-7b-chat": 17,
}
_MODELS = models = [
"accounts/fireworks/models/fireworks-llama-v2-13b",
"accounts/fireworks/models/fireworks-llama-v2-7b",
"accounts/fireworks/models/fireworks-llama-v2-13b-chat",
"accounts/fireworks/models/fireworks-llama-v2-7b-chat",
]
@pytest.mark.parametrize("model", _MODELS)
def test_fireworks_get_num_tokens(model: str) -> None:
"""Test get_tokens."""
llm = Fireworks(model=model)
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2

907
poetry.lock generated

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save