"""Groq Chat wrapper."""

from __future__ import annotations

import os
import warnings
from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
    cast,
)
from langchain_core._api import beta
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
    convert_to_openai_function,
    convert_to_openai_tool,
)


class ChatGroq(BaseChatModel):
    """`Groq` Chat large language models API.

    To use, you should have the
    environment variable ``GROQ_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the groq.create call
    can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_groq import ChatGroq

            model = ChatGroq(model_name="mixtral-8x7b-32768")
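
        A minimal invocation sketch (this assumes a valid ``GROQ_API_KEY`` is
        set in the environment; like other LangChain chat models, ``invoke``
        returns an ``AIMessage``):

        .. code-block:: python

            response = model.invoke("Explain why low latency matters for LLMs.")
            print(response.content)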
"""
    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `GROQ_API_KEY` if not provided."""
    groq_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    # to support explicit proxy for Groq
    groq_proxy: Optional[str] = None
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
    None."""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        values["model_kwargs"] = extra
        return values
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
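        # A temperature of exactly 0 is remapped to a tiny positive value,
        # presumably because the backend rejects temperature=0.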
        if values["temperature"] == 0:
            values["temperature"] = 1e-8
        values["groq_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY")
        )
        values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE")
        values["groq_proxy"] = values["groq_proxy"] or os.getenv("GROQ_PROXY")
        client_params = {
            "api_key": values["groq_api_key"].get_secret_value(),
            "base_url": values["groq_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
            "http_client": values["http_client"],
        }

        try:
            import groq

            if not values.get("client"):
                values["client"] = groq.Groq(**client_params).chat.completions
            if not values.get("async_client"):
                values["async_client"] = groq.AsyncGroq(
                    **client_params
                ).chat.completions
        except ImportError:
            raise ImportError(
                "Could not import groq python package. "
                "Please install it with `pip install groq`."
            )
        return values
    #
    # Serializable class method overrides
    #
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"groq_api_key": "GROQ_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    #
    # BaseChatModel method overrides
    #
    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "groq-chat"
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **kwargs,
        }
        response = self.client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **kwargs,
        }
        response = await self.async_client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)

        # groq api does not support streaming with tools yet
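        # Fall back to a single non-streaming call and re-emit the full
        # response as one chunk, so callers still receive an iterator.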
if " tools " in kwargs :
response = self . client . create (
messages = message_dicts , * * { * * params , * * kwargs }
)
chat_result = self . _create_chat_result ( response )
generation = chat_result . generations [ 0 ]
message = generation . message
chunk_ = ChatGenerationChunk (
message = AIMessageChunk (
content = message . content , additional_kwargs = message . additional_kwargs
) ,
generation_info = generation . generation_info ,
)
if run_manager :
geninfo = chunk_ . generation_info or { }
run_manager . on_llm_new_token (
chunk_ . text ,
chunk = chunk_ ,
logprobs = geninfo . get ( " logprobs " ) ,
)
yield chunk_
return
2024-02-22 15:36:16 +00:00
params = { * * params , * * kwargs , " stream " : True }
default_chunk_class = AIMessageChunk
for chunk in self . client . create ( messages = message_dicts , * * params ) :
if not isinstance ( chunk , dict ) :
chunk = chunk . dict ( )
if len ( chunk [ " choices " ] ) == 0 :
continue
choice = chunk [ " choices " ] [ 0 ]
chunk = _convert_delta_to_message_chunk (
choice [ " delta " ] , default_chunk_class
)
generation_info = { }
if finish_reason := choice . get ( " finish_reason " ) :
generation_info [ " finish_reason " ] = finish_reason
logprobs = choice . get ( " logprobs " )
if logprobs :
generation_info [ " logprobs " ] = logprobs
default_chunk_class = chunk . __class__
chunk = ChatGenerationChunk (
message = chunk , generation_info = generation_info or None
)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)

        # groq api does not support streaming with tools yet
if " tools " in kwargs :
response = await self . async_client . create (
messages = message_dicts , * * { * * params , * * kwargs }
)
chat_result = self . _create_chat_result ( response )
generation = chat_result . generations [ 0 ]
message = generation . message
chunk_ = ChatGenerationChunk (
message = AIMessageChunk (
content = message . content , additional_kwargs = message . additional_kwargs
) ,
generation_info = generation . generation_info ,
)
if run_manager :
geninfo = chunk_ . generation_info or { }
await run_manager . on_llm_new_token (
chunk_ . text ,
chunk = chunk_ ,
logprobs = geninfo . get ( " logprobs " ) ,
)
yield chunk_
return
2024-02-22 15:36:16 +00:00
params = { * * params , * * kwargs , " stream " : True }
default_chunk_class = AIMessageChunk
async for chunk in await self . async_client . create (
messages = message_dicts , * * params
) :
if not isinstance ( chunk , dict ) :
chunk = chunk . dict ( )
if len ( chunk [ " choices " ] ) == 0 :
continue
choice = chunk [ " choices " ] [ 0 ]
chunk = _convert_delta_to_message_chunk (
choice [ " delta " ] , default_chunk_class
)
generation_info = { }
if finish_reason := choice . get ( " finish_reason " ) :
generation_info [ " finish_reason " ] = finish_reason
logprobs = choice . get ( " logprobs " )
if logprobs :
generation_info [ " logprobs " ] = logprobs
default_chunk_class = chunk . __class__
chunk = ChatGenerationChunk (
message = chunk , generation_info = generation_info or None
)
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.text, chunk=chunk, logprobs=logprobs
                )
yield chunk

    #
    # Internal methods
    #
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Groq API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params
    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        function_call: Optional[
            Union[_FunctionCall, str, Literal["auto", "none"]]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind functions (and other objects) to this chat model.

        Model is compatible with OpenAI function-calling API.

        NOTE: Using bind_tools is recommended instead, as the `functions` and
        `function_call` request parameters are officially deprecated.

        Args:
            functions: A list of function definitions to bind to this chat model.
                Can be a dictionary, pydantic model, or callable. Pydantic
                models and callables will be automatically converted to
                their schema dictionary representation.
            function_call: Which function to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any).
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
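
        A minimal usage sketch (``get_weather`` is a hypothetical example
        function; callables are converted via their signature and docstring):

        .. code-block:: python

            from langchain_groq import ChatGroq

            def get_weather(city: str) -> str:
                '''Get the current weather for a city.'''
                ...

            llm = ChatGroq(temperature=0)
            llm_with_functions = llm.bind_functions([get_weather])
            # Function calls, if any, arrive in the response message's
            # additional_kwargs["function_call"].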
"""
        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
        if function_call is not None:
            function_call = (
                {"name": function_call}
                if isinstance(function_call, str)
                and function_call not in ("auto", "none")
                else function_call
            )
            if isinstance(function_call, dict) and len(formatted_functions) != 1:
                raise ValueError(
                    "When specifying `function_call`, you must provide exactly one "
                    "function."
                )
            if (
                isinstance(function_call, dict)
                and formatted_functions[0]["name"] != function_call["name"]
            ):
                raise ValueError(
                    f"Function call {function_call} was specified, but the only "
                    f"provided function was {formatted_functions[0]['name']}."
                )
            kwargs = {**kwargs, "function_call": function_call}
        return super().bind(
            functions=formatted_functions,
            **kwargs,
        )
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "any", "none"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function,
                "auto" to automatically determine which function to call
                with the option to not call any function, "any" to enforce that some
                function is called, or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
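
        A minimal usage sketch (``GetWeather`` is a hypothetical pydantic
        schema; structured calls arrive in the response's ``additional_kwargs``):

        .. code-block:: python

            from langchain_core.pydantic_v1 import BaseModel, Field
            from langchain_groq import ChatGroq

            class GetWeather(BaseModel):
                '''Get the current weather for a city.'''

                city: str = Field(description="City to look up")

            llm = ChatGroq(temperature=0)
            llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
            msg = llm_with_tools.invoke("What's the weather in Paris?")
            # msg.additional_kwargs["tool_calls"] holds the structured call.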
"""
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if isinstance(tool_choice, str) and (
                tool_choice not in ("auto", "any", "none")
            ):
                tool_choice = {"type": "function", "function": {"name": tool_choice}}
            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, dict) and (
                formatted_tools[0]["function"]["name"]
                != tool_choice["function"]["name"]
            ):
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tool was {formatted_tools[0]['function']['name']}."
                )
            if isinstance(tool_choice, bool):
                if len(tools) > 1:
                    raise ValueError(
                        "tool_choice can only be True when there is one tool. Received "
                        f"{len(tools)} tools."
                    )
                tool_name = formatted_tools[0]["function"]["name"]
                tool_choice = {
                    "type": "function",
                    "function": {"name": tool_name},
                }
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
    @beta()
    def with_structured_output(
        self,
        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
                then the model output will be an object of that class. If a dict then
                the model output will be a dict. With a Pydantic class the returned
                attributes will be validated, whereas with a dict they will not be. If
                `method` is "function_calling" and `schema` is a dict, then the dict
                must match the OpenAI function-calling spec.
            method: The method for steering model generation, either "function_calling"
                or "json_mode". If "function_calling" then the schema will be converted
                to an OpenAI function and the returned model will make use of the
                function-calling API. If "json_mode" then Groq's JSON mode will be
                used. Note that if using "json_mode" then you must include instructions
                for formatting the output into the desired schema in the model call.
            include_raw: If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes any ChatModel input and returns as output:

            If include_raw is True then a dict with keys:
                raw: BaseMessage
                parsed: Optional[_DictOrPydantic]
                parsing_error: Optional[BaseException]

            If include_raw is False then just _DictOrPydantic is returned,
            where _DictOrPydantic depends on the schema:

            If schema is a Pydantic class then _DictOrPydantic is the Pydantic
            class.

            If schema is a dict then _DictOrPydantic is a dict.
        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> AnswerWithJustification(
                #     answer='A pound of bricks and a pound of feathers weigh the same.',
                #     justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."
                # )

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_01htjn3cspevxbqc1d7nkk8wab', 'function': {'arguments': '{"answer": "A pound of bricks and a pound of feathers weigh the same.", "justification": "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The \'pound\' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", "unit": "pounds"}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}, id='run-456beee6-65f6-4e80-88af-a6065480822c-0'),
                #     'parsed': AnswerWithJustification(answer='A pound of bricks and a pound of feathers weigh the same.', justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."),
                #     'parsing_error': None
                # }

        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'A pound of bricks and a pound of feathers weigh the same.',
                #     'justification': "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.",
                #     'unit': 'pounds'
                # }

        Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(
                    AnswerWithJustification,
                    method="json_mode",
                    include_raw=True
                )

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n    "answer": "A pound of bricks is the same weight as a pound of feathers.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed."\n}', id='run-e5453bc5-5025-4833-95f9-4967bf6d5c4f-0'),
                #     'parsed': AnswerWithJustification(answer='A pound of bricks is the same weight as a pound of feathers.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed.'),
                #     'parsing_error': None
                # }

        Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n    "answer": "A pound of bricks is the same weight as a pound of feathers.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn\'t change the weight, only the volume or space that the material takes up."\n}', id='run-a4abbdb6-c20e-456f-bfff-da906a7e76b5-0'),
                #     'parsed': {
                #         'answer': 'A pound of bricks is the same weight as a pound of feathers.',
                #         'justification': "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn't change the weight, only the volume or space that the material takes up."
                #     },
                #     'parsing_error': None
                # }
        """  # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            llm = self.bind_tools([schema], tool_choice=True)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema], first_tool_only=True
                )
            else:
                key_name = convert_to_openai_tool(schema)["function"]["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)


class _FunctionCall(TypedDict):
    name: str


#
# Type conversion helpers
#
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
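
    Example (a minimal sketch of the expected mapping):

    .. code-block:: python

        from langchain_core.messages import HumanMessage

        _convert_message_to_dict(HumanMessage(content="hi"))
        # -> {"role": "user", "content": "hi"}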
"""
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
        if "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            # If tool calls only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise TypeError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
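    """Convert a streaming delta dict into the matching message chunk class.

    The chunk class is chosen from the delta's ``role``, falling back to
    ``default_class`` when the role is absent.
    """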
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: Dict = {}
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    if _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = _dict["tool_calls"]

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
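
    Example (a minimal sketch; the inverse of ``_convert_message_to_dict``):

    .. code-block:: python

        _convert_dict_to_message({"role": "user", "content": "hi"})
        # -> HumanMessage(content="hi")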
"""
    id_ = _dict.get("id")
    role = _dict.get("role")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""))
    elif role == "assistant":
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        if tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = tool_calls
        return AIMessage(content=content, id=id_, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=_dict.get("content", ""))
    elif role == "function":
        return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=_dict.get("tool_call_id"),
            additional_kwargs=additional_kwargs,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role)