fireworks[patch]: add usage_metadata to (a)invoke and (a)stream (#22906)
parent d1b7a934aa · commit f40b2c6f9d
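For orientation, a minimal sketch of what this patch enables from the caller's side, assuming FIREWORKS_API_KEY is set (as the integration tests below require). The usage_metadata keys shown are the ones populated in this diff; the prompt text is borrowed from the tests.

from langchain_fireworks import ChatFireworks

llm = ChatFireworks()  # needs FIREWORKS_API_KEY in the environment

# invoke()/ainvoke(): the returned AIMessage now carries usage_metadata
# built from the provider's "usage" block.
msg = llm.invoke("I'm Pickle Rick")
print(msg.usage_metadata)  # {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}

# stream()/astream(): chunks carry usage_metadata (on the chunk that reports
# usage), and summing AIMessageChunks aggregates the counts.
full = None
for chunk in llm.stream("I'm Pickle Rick"):
    full = chunk if full is None else full + chunk
print(full.usage_metadata)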
@@ -179,9 +179,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict


-def _convert_delta_to_message_chunk(
-    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    choice = chunk["choices"][0]
+    _dict = choice["delta"]
     role = cast(str, _dict.get("role"))
     content = cast(str, _dict.get("content") or "")
     additional_kwargs: Dict = {}
@@ -210,10 +212,21 @@ def _convert_delta_to_message_chunk(
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
+        if usage := chunk.get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = {
+                "input_tokens": input_tokens,
+                "output_tokens": output_tokens,
+                "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
+            }
+        else:
+            usage_metadata = None
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
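To make the new branch concrete, here is an illustrative chunk in the dict shape _convert_chunk_to_message_chunk now receives (field names are taken from the hunk above; the token counts are hypothetical), together with the usage_metadata it would derive:

# Hypothetical streaming chunk, already converted to a dict; the converter reads
# chunk["choices"][0]["delta"] for content and the top-level "usage" block for counts.
chunk = {
    "choices": [{"delta": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 7, "completion_tokens": 2, "total_tokens": 9},
}

usage = chunk["usage"]
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = {
    "input_tokens": input_tokens,
    "output_tokens": output_tokens,
    "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
}
assert usage_metadata == {"input_tokens": 7, "output_tokens": 2, "total_tokens": 9}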
@@ -412,29 +425,29 @@ class ChatFireworks(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         for chunk in self.client.create(messages=message_dicts, **params):
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
+                )
+            yield generation_chunk

     def _generate(
         self,
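_stream now yields one ChatGenerationChunk per delta, with usage_metadata attached only to the chunk whose payload contained a "usage" block. Downstream aggregation relies on AIMessageChunk addition merging those counts, which the updated test further down asserts; a small sketch with hypothetical values:

from langchain_core.messages import AIMessageChunk

# Typical chunk sequence: only the final chunk reports usage.
first = AIMessageChunk(content="Hello")
last = AIMessageChunk(
    content="!",
    usage_metadata={"input_tokens": 7, "output_tokens": 2, "total_tokens": 9},
)

full = first + last
print(full.content)         # "Hello!"
print(full.usage_metadata)  # counts carried over from the chunk that reported usage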
@@ -472,8 +485,15 @@ class ChatFireworks(BaseChatModel):
         generations = []
         if not isinstance(response, dict):
             response = response.dict()
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -482,7 +502,6 @@ class ChatFireworks(BaseChatModel):
                 generation_info=generation_info,
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {
             "token_usage": token_usage,
             "model_name": self.model_name,
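For the non-streaming path, an illustrative response in the dict shape _generate consumes after response.dict() (token counts hypothetical), and the usage_metadata the loop above attaches to the resulting AIMessage:

# Hypothetical completion response, already converted to a dict.
response = {
    "choices": [
        {"message": {"role": "assistant", "content": "Hi!"}, "finish_reason": "stop"}
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}

token_usage = response.get("usage", {})
usage_metadata = {
    "input_tokens": token_usage.get("prompt_tokens", 0),
    "output_tokens": token_usage.get("completion_tokens", 0),
    "total_tokens": token_usage.get("total_tokens", 0),
}
# This dict is set on each AIMessage built from response["choices"], so callers
# of invoke()/ainvoke() can read message.usage_metadata directly.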
@@ -500,31 +519,31 @@ class ChatFireworks(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         async for chunk in self.async_client.acreate(messages=message_dicts, **params):
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
                 await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                    token=generation_chunk.text,
+                    chunk=generation_chunk,
+                    logprobs=logprobs,
                 )
-            yield chunk
+            yield generation_chunk

     async def _agenerate(
         self,
libs/partners/fireworks/poetry.lock (generated, 17 lines changed)
@@ -572,7 +572,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.2.4"
+version = "0.2.6"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -581,15 +581,12 @@ develop = true

 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.66"
-packaging = "^23.2"
+langsmith = "^0.1.75"
+packaging = ">=23.2,<25"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"

-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -613,13 +610,13 @@ url = "../../standard-tests"

 [[package]]
 name = "langsmith"
-version = "0.1.73"
+version = "0.1.77"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
-    {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
+    {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"},
+    {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"},
 ]

 [package.dependencies]
@@ -1551,4 +1548,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "1bd993cb034f7eeb243d4c0861075008065d31c6c707aeb2e99c6214d72fb409"
+content-hash = "d68944d6707245475f18e901001c98aabdf7aae7944c552e8f058f4806c83f0c"
@@ -12,7 +12,7 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.2.0,<0.3"
+langchain-core = ">=0.2.2,<0.3"
 fireworks-ai = ">=0.13.0"
 openai = "^1.10.0"
 requests = "^2"
@@ -4,8 +4,9 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
 """

 import json
+from typing import Optional

-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
 from langchain_core.pydantic_v1 import BaseModel

 from langchain_fireworks import ChatFireworks
@@ -93,8 +94,28 @@ async def test_astream() -> None:
     """Test streaming tokens from ChatFireworks."""
     llm = ChatFireworks()

+    full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
+        full = token if full is None else full + token
+        if token.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )


 async def test_abatch() -> None:
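The synchronous stream() path goes through the same _convert_chunk_to_message_chunk helper, so a sync counterpart of the checks above would look roughly like this (a sketch, not part of this diff):

from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk

from langchain_fireworks import ChatFireworks


def check_stream_usage() -> None:
    """Rough sync mirror of test_astream's usage_metadata assertions."""
    llm = ChatFireworks()
    full: Optional[BaseMessageChunk] = None
    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token, AIMessageChunk)
        full = token if full is None else full + token
    assert isinstance(full, AIMessageChunk)
    assert full.usage_metadata is not None
    assert (
        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
        == full.usage_metadata["total_tokens"]
    )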
@@ -21,17 +21,6 @@ class TestFireworksStandard(ChatModelIntegrationTests):
             "temperature": 0,
         }

-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
-
     @pytest.mark.xfail(reason="Not yet implemented.")
     def test_tool_message_histories_list_content(
         self,