"""Fireworks chat wrapper."""

from __future__ import annotations

import logging
import os
from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
    cast,
)
from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
    make_invalid_tool_call,
    parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
    convert_to_openai_function,
    convert_to_openai_tool,
)
from langchain_core.utils.utils import build_extra_kwargs

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    role = _dict.get("role")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""))
    elif role == "assistant":
        # Fix for azure
        # Also Fireworks returns None for tool invocations
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        tool_calls = []
        invalid_tool_calls = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    invalid_tool_calls.append(
                        dict(make_invalid_tool_call(raw_tool_call, str(e)))
                    )
        return AIMessage(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    elif role == "system":
        return SystemMessage(content=_dict.get("content", ""))
    elif role == "function":
        return FunctionMessage(
            content=_dict.get("content", ""), name=_dict.get("name", "")
        )
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=_dict.get("tool_call_id", ""),
            additional_kwargs=additional_kwargs,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role or "")
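

# Illustrative sketch (comments only, not executed here): how a hypothetical
# assistant payload from the Fireworks API maps onto a LangChain message. The
# literal dict below is a made-up example, not captured API output:
#
#     msg = _convert_dict_to_message({"role": "assistant", "content": "Hi"})
#     assert isinstance(msg, AIMessage) and msg.content == "Hi"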


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
        if "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            # If tool calls only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise TypeError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
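

# Illustrative sketch (comments only, not executed): a message round-trips
# through the wire format used above. The message text is a made-up example:
#
#     payload = _convert_message_to_dict(HumanMessage(content="Hello"))
#     assert payload == {"role": "user", "content": "Hello"}
#     assert _convert_dict_to_message(payload) == HumanMessage(content="Hello")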


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: Dict = {}
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict . get ( " tool_calls " ) :
additional_kwargs [ " tool_calls " ] = raw_tool_calls
try :
tool_call_chunks = [
{
" name " : rtc [ " function " ] . get ( " name " ) ,
" args " : rtc [ " function " ] . get ( " arguments " ) ,
" id " : rtc . get ( " id " ) ,
" index " : rtc [ " index " ] ,
}
for rtc in raw_tool_calls
]
except KeyError :
pass
else :
tool_call_chunks = [ ]
2024-02-23 20:45:47 +00:00
if role == " user " or default_class == HumanMessageChunk :
return HumanMessageChunk ( content = content )
elif role == " assistant " or default_class == AIMessageChunk :
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore
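

# Illustrative sketch (comments only, not executed): streamed deltas become
# AIMessageChunk objects that accumulate with `+`; tool-call chunks that share
# an `index` are merged by langchain_core. The delta dicts are hypothetical:
#
#     first = _convert_delta_to_message_chunk(
#         {"role": "assistant", "content": "Hel"}, AIMessageChunk
#     )
#     second = _convert_delta_to_message_chunk({"content": "lo"}, AIMessageChunk)
#     assert (first + second).content == "Hello"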


class _FunctionCall(TypedDict):
    name: str


# This is basically ChatOpenAI copied and adapted to make ChatFireworks, except
# - I needed to gut out tiktoken and some of the token estimation logic
#   (not sure how important it is)
# - the environment variable is different
# We should refactor into some OpenAI-like base class in the future.
class ChatFireworks(BaseChatModel):
    """`Fireworks` Chat large language models API.

    To use, you should have the
    environment variable ``FIREWORKS_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the fireworks.create call
    can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_fireworks.chat_models import ChatFireworks
            fireworks = ChatFireworks(
                model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "fireworks"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}
        if self.fireworks_api_base:
            attributes["fireworks_api_base"] = self.fireworks_api_base
        return attributes

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(
        default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
    )
    """Model name to use."""
    temperature: float = 0.0
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
    """Automatically inferred from env var `FIREWORKS_API_KEY` if not provided."""
    fireworks_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
    None."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values
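
    # Illustrative note (an assumption about the standard pydantic flow, not
    # executed): ChatFireworks(model="...", top_p=0.9) has no `top_p` field, so
    # build_extra routes it into model_kwargs; _default_params later splices it
    # back into the request payload via **self.model_kwargs.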

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        values["fireworks_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
        )
        values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
            "FIREWORKS_API_BASE"
        )
        client_params = {
            "api_key": (
                values["fireworks_api_key"].get_secret_value()
                if values["fireworks_api_key"]
                else None
            ),
            "base_url": values["fireworks_api_base"],
            "timeout": values["request_timeout"],
        }
        if not values.get("client"):
            values["client"] = Fireworks(**client_params).chat.completions
        if not values.get("async_client"):
            values["async_client"] = AsyncFireworks(**client_params).chat.completions
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Fireworks API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined
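
    # Illustrative note (comments only): token usage dicts are summed key-wise,
    # e.g. outputs carrying {"total_tokens": 7} and {"total_tokens": 5} combine
    # to {"total_tokens": 12}; the counts here are made up for illustration.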

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.client.create(messages=message_dicts, **params):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
            yield chunk
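
    # Illustrative sketch (comments only, not executed): callers normally reach
    # _stream through the public BaseChatModel.stream API:
    #
    #     llm = ChatFireworks()
    #     for chunk in llm.stream("Tell me a joke"):
    #         print(chunk.content, end="", flush=True)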

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in self.async_client.acreate(messages=message_dicts, **params):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.text, chunk=chunk, logprobs=logprobs
                )
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await self.async_client.acreate(messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name, **self._default_params}

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return {
            "model": self.model_name,
            **super()._get_invocation_params(stop=stop),
            **self._default_params,
            **kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "fireworks-chat"

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        function_call: Optional[
            Union[_FunctionCall, str, Literal["auto", "none"]]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind functions (and other objects) to this chat model.

        Assumes model is compatible with Fireworks function-calling API.

        NOTE: Using bind_tools is recommended instead, as the `functions` and
            `function_call` request parameters are officially marked as deprecated by
            Fireworks.

        Args:
            functions: A list of function definitions to bind to this chat model.
                Can be a dictionary, pydantic model, or callable. Pydantic
                models and callables will be automatically converted to
                their schema dictionary representation.
            function_call: Which function to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any).
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
        if function_call is not None:
            function_call = (
                {"name": function_call}
                if isinstance(function_call, str)
                and function_call not in ("auto", "none")
                else function_call
            )
            if isinstance(function_call, dict) and len(formatted_functions) != 1:
                raise ValueError(
                    "When specifying `function_call`, you must provide exactly one "
                    "function."
                )
            if (
                isinstance(function_call, dict)
                and formatted_functions[0]["name"] != function_call["name"]
            ):
                raise ValueError(
                    f"Function call {function_call} was specified, but the only "
                    f"provided function was {formatted_functions[0]['name']}."
                )
            kwargs = {**kwargs, "function_call": function_call}
        return super().bind(
            functions=formatted_functions,
            **kwargs,
        )

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "any", "none"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with Fireworks tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function,
                "auto" to automatically determine which function to call
                with the option to not call any function, "any" to enforce that some
                function is called, or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if isinstance(tool_choice, str) and (
                tool_choice not in ("auto", "any", "none")
            ):
                tool_choice = {"type": "function", "function": {"name": tool_choice}}
            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, dict) and (
                formatted_tools[0]["function"]["name"]
                != tool_choice["function"]["name"]
            ):
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tool was {formatted_tools[0]['function']['name']}."
                )
            if isinstance(tool_choice, bool):
                if len(tools) > 1:
                    raise ValueError(
                        "tool_choice can only be True when there is one tool. Received "
                        f"{len(tools)} tools."
                    )
                tool_name = formatted_tools[0]["function"]["name"]
                tool_choice = {
                    "type": "function",
                    "function": {"name": tool_name},
                }
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
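
    # Illustrative sketch (comments only, not executed): binding a hypothetical
    # pydantic tool and forcing the model to call it via tool_choice=True:
    #
    #     class GetWeather(BaseModel):
    #         """Get the current weather in a given city."""
    #         city: str
    #
    #     llm_with_tools = llm.bind_tools([GetWeather], tool_choice=True)
    #     ai_msg = llm_with_tools.invoke("How warm is it in Paris?")
    #     # ai_msg.tool_calls -> [{"name": "GetWeather", "args": {"city": "Paris"}, ...}]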

    def with_structured_output(
        self,
        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
                then the model output will be an object of that class. If a dict then
                the model output will be a dict. With a Pydantic class the returned
                attributes will be validated, whereas with a dict they will not be. If
                `method` is "function_calling" and `schema` is a dict, then the dict
                must match the Fireworks function-calling spec.
            method: The method for steering model generation, either "function_calling"
                or "json_mode". If "function_calling" then the schema will be converted
                to a Fireworks function and the returned model will make use of the
                function-calling API. If "json_mode" then Fireworks's JSON mode will be
                used. Note that if using "json_mode" then you must include instructions
                for formatting the output into the desired schema into the model call.
            include_raw: If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes any ChatModel input and returns as output:

                If include_raw is True then a dict with keys:
                    raw: BaseMessage
                    parsed: Optional[_DictOrPydantic]
                    parsing_error: Optional[BaseException]

                If include_raw is False then just _DictOrPydantic is returned,
                where _DictOrPydantic depends on the schema:

                    If schema is a Pydantic class then _DictOrPydantic is the Pydantic
                        class.

                    If schema is a dict then _DictOrPydantic is a dict.

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_fireworks import ChatFireworks
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

                # -> AnswerWithJustification(
                #     answer='They weigh the same',
                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                # )

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
            .. code-block:: python

                from langchain_fireworks import ChatFireworks
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_fireworks import ChatFireworks
                from langchain_core.pydantic_v1 import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }

        Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_fireworks import ChatFireworks
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    answer: str
                    justification: str

                llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", temperature=0)
                structured_llm = llm.with_structured_output(
                    AnswerWithJustification,
                    method="json_mode",
                    include_raw=True
                )

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
                #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                #     'parsing_error': None
                # }

        Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_fireworks import ChatFireworks

                llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", temperature=0)
                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
                #     'parsed': {
                #         'answer': 'They are both the same weight.',
                #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                #     },
                #     'parsing_error': None
                # }
        """  # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            llm = self.bind_tools([schema], tool_choice=True)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema], first_tool_only=True
                )
            else:
                key_name = convert_to_openai_tool(schema)["function"]["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)