|
|
|
@ -307,7 +307,7 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
)
|
|
|
|
|
chat_result = self._create_chat_result(response)
|
|
|
|
|
generation = chat_result.generations[0]
|
|
|
|
|
message = generation.message
|
|
|
|
|
message = cast(AIMessage, generation.message)
|
|
|
|
|
tool_call_chunks = [
|
|
|
|
|
{
|
|
|
|
|
"name": rtc["function"].get("name"),
|
|
|
|
@ -322,6 +322,7 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
content=message.content,
|
|
|
|
|
additional_kwargs=message.additional_kwargs,
|
|
|
|
|
tool_call_chunks=tool_call_chunks,
|
|
|
|
|
usage_metadata=message.usage_metadata,
|
|
|
|
|
),
|
|
|
|
|
generation_info=generation.generation_info,
|
|
|
|
|
)
|
|
|
|
@ -337,30 +338,30 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
params = {**params, **kwargs, "stream": True}
|
|
|
|
|
|
|
|
|
|
default_chunk_class = AIMessageChunk
|
|
|
|
|
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
|
|
|
|
for chunk in self.client.create(messages=message_dicts, **params):
|
|
|
|
|
if not isinstance(chunk, dict):
|
|
|
|
|
chunk = chunk.dict()
|
|
|
|
|
if len(chunk["choices"]) == 0:
|
|
|
|
|
continue
|
|
|
|
|
choice = chunk["choices"][0]
|
|
|
|
|
chunk = _convert_delta_to_message_chunk(
|
|
|
|
|
choice["delta"], default_chunk_class
|
|
|
|
|
)
|
|
|
|
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
|
|
|
|
generation_info = {}
|
|
|
|
|
if finish_reason := choice.get("finish_reason"):
|
|
|
|
|
generation_info["finish_reason"] = finish_reason
|
|
|
|
|
logprobs = choice.get("logprobs")
|
|
|
|
|
if logprobs:
|
|
|
|
|
generation_info["logprobs"] = logprobs
|
|
|
|
|
default_chunk_class = chunk.__class__
|
|
|
|
|
chunk = ChatGenerationChunk(
|
|
|
|
|
message=chunk, generation_info=generation_info or None
|
|
|
|
|
default_chunk_class = message_chunk.__class__
|
|
|
|
|
generation_chunk = ChatGenerationChunk(
|
|
|
|
|
message=message_chunk, generation_info=generation_info or None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if run_manager:
|
|
|
|
|
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
|
|
|
|
yield chunk
|
|
|
|
|
run_manager.on_llm_new_token(
|
|
|
|
|
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
|
|
|
|
|
)
|
|
|
|
|
yield generation_chunk
|
|
|
|
|
|
|
|
|
|
async def _astream(
|
|
|
|
|
self,
|
|
|
|
@ -378,7 +379,7 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
)
|
|
|
|
|
chat_result = self._create_chat_result(response)
|
|
|
|
|
generation = chat_result.generations[0]
|
|
|
|
|
message = generation.message
|
|
|
|
|
message = cast(AIMessage, generation.message)
|
|
|
|
|
tool_call_chunks = [
|
|
|
|
|
{
|
|
|
|
|
"name": rtc["function"].get("name"),
|
|
|
|
@ -393,6 +394,7 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
content=message.content,
|
|
|
|
|
additional_kwargs=message.additional_kwargs,
|
|
|
|
|
tool_call_chunks=tool_call_chunks,
|
|
|
|
|
usage_metadata=message.usage_metadata,
|
|
|
|
|
),
|
|
|
|
|
generation_info=generation.generation_info,
|
|
|
|
|
)
|
|
|
|
@ -408,7 +410,7 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
params = {**params, **kwargs, "stream": True}
|
|
|
|
|
|
|
|
|
|
default_chunk_class = AIMessageChunk
|
|
|
|
|
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
|
|
|
|
async for chunk in await self.async_client.create(
|
|
|
|
|
messages=message_dicts, **params
|
|
|
|
|
):
|
|
|
|
@ -417,25 +419,25 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
if len(chunk["choices"]) == 0:
|
|
|
|
|
continue
|
|
|
|
|
choice = chunk["choices"][0]
|
|
|
|
|
chunk = _convert_delta_to_message_chunk(
|
|
|
|
|
choice["delta"], default_chunk_class
|
|
|
|
|
)
|
|
|
|
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
|
|
|
|
generation_info = {}
|
|
|
|
|
if finish_reason := choice.get("finish_reason"):
|
|
|
|
|
generation_info["finish_reason"] = finish_reason
|
|
|
|
|
logprobs = choice.get("logprobs")
|
|
|
|
|
if logprobs:
|
|
|
|
|
generation_info["logprobs"] = logprobs
|
|
|
|
|
default_chunk_class = chunk.__class__
|
|
|
|
|
chunk = ChatGenerationChunk(
|
|
|
|
|
message=chunk, generation_info=generation_info or None
|
|
|
|
|
default_chunk_class = message_chunk.__class__
|
|
|
|
|
generation_chunk = ChatGenerationChunk(
|
|
|
|
|
message=message_chunk, generation_info=generation_info or None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if run_manager:
|
|
|
|
|
await run_manager.on_llm_new_token(
|
|
|
|
|
token=chunk.text, chunk=chunk, logprobs=logprobs
|
|
|
|
|
token=generation_chunk.text,
|
|
|
|
|
chunk=generation_chunk,
|
|
|
|
|
logprobs=logprobs,
|
|
|
|
|
)
|
|
|
|
|
yield chunk
|
|
|
|
|
yield generation_chunk
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# Internal methods
|
|
|
|
@ -459,8 +461,19 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
generations = []
|
|
|
|
|
if not isinstance(response, dict):
|
|
|
|
|
response = response.dict()
|
|
|
|
|
token_usage = response.get("usage", {})
|
|
|
|
|
for res in response["choices"]:
|
|
|
|
|
message = _convert_dict_to_message(res["message"])
|
|
|
|
|
if token_usage and isinstance(message, AIMessage):
|
|
|
|
|
input_tokens = token_usage.get("prompt_tokens", 0)
|
|
|
|
|
output_tokens = token_usage.get("completion_tokens", 0)
|
|
|
|
|
message.usage_metadata = {
|
|
|
|
|
"input_tokens": input_tokens,
|
|
|
|
|
"output_tokens": output_tokens,
|
|
|
|
|
"total_tokens": token_usage.get(
|
|
|
|
|
"total_tokens", input_tokens + output_tokens
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
generation_info = dict(finish_reason=res.get("finish_reason"))
|
|
|
|
|
if "logprobs" in res:
|
|
|
|
|
generation_info["logprobs"] = res["logprobs"]
|
|
|
|
@ -469,7 +482,6 @@ class ChatGroq(BaseChatModel):
|
|
|
|
|
generation_info=generation_info,
|
|
|
|
|
)
|
|
|
|
|
generations.append(gen)
|
|
|
|
|
token_usage = response.get("usage", {})
|
|
|
|
|
llm_output = {
|
|
|
|
|
"token_usage": token_usage,
|
|
|
|
|
"model_name": self.model_name,
|
|
|
|
@ -892,9 +904,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|
|
|
|
return message_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_delta_to_message_chunk(
|
|
|
|
|
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
|
|
|
|
def _convert_chunk_to_message_chunk(
|
|
|
|
|
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
|
|
|
|
) -> BaseMessageChunk:
|
|
|
|
|
choice = chunk["choices"][0]
|
|
|
|
|
_dict = choice["delta"]
|
|
|
|
|
role = cast(str, _dict.get("role"))
|
|
|
|
|
content = cast(str, _dict.get("content") or "")
|
|
|
|
|
additional_kwargs: Dict = {}
|
|
|
|
@ -909,7 +923,21 @@ def _convert_delta_to_message_chunk(
|
|
|
|
|
if role == "user" or default_class == HumanMessageChunk:
|
|
|
|
|
return HumanMessageChunk(content=content)
|
|
|
|
|
elif role == "assistant" or default_class == AIMessageChunk:
|
|
|
|
|
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
|
|
|
|
if usage := (chunk.get("x_groq") or {}).get("usage"):
|
|
|
|
|
input_tokens = usage.get("prompt_tokens", 0)
|
|
|
|
|
output_tokens = usage.get("completion_tokens", 0)
|
|
|
|
|
usage_metadata = {
|
|
|
|
|
"input_tokens": input_tokens,
|
|
|
|
|
"output_tokens": output_tokens,
|
|
|
|
|
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
usage_metadata = None
|
|
|
|
|
return AIMessageChunk(
|
|
|
|
|
content=content,
|
|
|
|
|
additional_kwargs=additional_kwargs,
|
|
|
|
|
usage_metadata=usage_metadata,
|
|
|
|
|
)
|
|
|
|
|
elif role == "system" or default_class == SystemMessageChunk:
|
|
|
|
|
return SystemMessageChunk(content=content)
|
|
|
|
|
elif role == "function" or default_class == FunctionMessageChunk:
|
|
|
|
|