Harrison/chat token usage (#1785)

Harrison Chase, committed by GitHub (1 year ago)
parent 7de2ada3ea
commit ef4945af6b

@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "522686de",
    "metadata": {
     "tags": []
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "62e0dbc3",
    "metadata": {
     "tags": []
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "76a6e7b0-e927-4bfb-a414-1332a4149106",
    "metadata": {
     "tags": []
@@ -68,7 +68,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -87,7 +87,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
    "metadata": {
     "tags": []
@@ -99,7 +99,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -122,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "2b21fc52-74b6-4950-ab78-45d12c68fb4d",
    "metadata": {
     "tags": []
@@ -131,10 +131,10 @@
     {
      "data": {
       "text/plain": [
-       "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output=None)"
+       "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output={'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}})"
      ]
     },
-     "execution_count": 6,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -150,7 +150,39 @@
     " HumanMessage(content=\"Translate this sentence from English to French. I love artificial intelligence.\")\n",
     " ],\n",
     "]\n",
-    "chat.generate(batch_messages)"
+    "result = chat.generate(batch_messages)\n",
+    "result"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "2960f50f",
+   "metadata": {},
+   "source": [
+    "You can recover things like token usage from this LLMResult"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "a6186bee",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'token_usage': {'prompt_tokens': 71,\n",
+       " 'completion_tokens': 18,\n",
+       " 'total_tokens': 89}}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "result.llm_output"
+   ]
+  },
   {
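
What the notebook change boils down to: capture the return value of chat.generate() instead of discarding it, then read its llm_output attribute. A minimal sketch of that usage, assuming an OPENAI_API_KEY in the environment and the langchain module layout at the time of this commit:

    from langchain.chat_models import ChatOpenAI
    from langchain.schema import HumanMessage

    chat = ChatOpenAI(temperature=0)
    batch_messages = [
        [HumanMessage(content="Translate this sentence from English to French. I love programming.")],
        [HumanMessage(content="Translate this sentence from English to French. I love artificial intelligence.")],
    ]

    # generate() returns an LLMResult; llm_output carries provider metadata,
    # here the token usage summed across the whole batch, e.g.
    # {'token_usage': {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}}
    result = chat.generate(batch_messages)
    print(result.llm_output)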

@@ -43,19 +43,26 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         """
         return callback_manager or get_callback_manager()
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        return {}
+
     def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [self._generate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [await self._agenerate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
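
The point of the new hook: generate() now collects each per-call ChatResult.llm_output and hands the list to _combine_llm_outputs, whose base implementation returns {} so that providers opt in by overriding it. A dependency-free sketch of that control flow (the Fake* names are stand-ins for illustration, not library classes):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class FakeChatResult:
        """Stand-in for ChatResult: generations plus optional provider metadata."""
        generations: list
        llm_output: Optional[dict] = None

    class FakeChatModel:
        def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
            return {}  # base default: no aggregation unless a subclass overrides

        def _generate(self, messages: list) -> FakeChatResult:
            return FakeChatResult(["ok"], {"token_usage": {"total_tokens": 5}})

        def generate(self, message_batches: List[list]) -> dict:
            # Same shape as the hunk above: run each batch, then merge metadata.
            results = [self._generate(m) for m in message_batches]
            llm_output = self._combine_llm_outputs([r.llm_output for r in results])
            return {"generations": [r.generations for r in results],
                    "llm_output": llm_output}

    print(FakeChatModel().generate([["hi"], ["bye"]])["llm_output"])  # {} until overridden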

@@ -97,7 +97,8 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
         message = _convert_dict_to_message(res["message"])
         gen = ChatGeneration(message=message)
         generations.append(gen)
-    return ChatResult(generations=generations)
+    llm_output = {"token_usage": response["usage"]}
+    return ChatResult(generations=generations, llm_output=llm_output)
 
 
 class ChatOpenAI(BaseChatModel, BaseModel):
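
For context on the line just added: the raw OpenAI chat completions response carries a top-level usage block, and _create_chat_result simply forwards it. A trimmed example of that response shape (field names as returned by the OpenAI API; the numbers are illustrative):

    response = {
        "choices": [
            {"message": {"role": "assistant", "content": "J'aime programmer."}}
        ],
        "usage": {"prompt_tokens": 36, "completion_tokens": 8, "total_tokens": 44},
    }

    # Exactly what the hunk adds inside _create_chat_result:
    llm_output = {"token_usage": response["usage"]}
    print(llm_output)
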
@@ -221,6 +222,19 @@ class ChatOpenAI(BaseChatModel, BaseModel):
         return _completion_with_retry(**kwargs)
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                raise ValueError("Should always be something for OpenAI.")
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage}
+
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
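
Stripped of the class, ChatOpenAI's override is a key-wise sum over the per-call token_usage dicts. A self-contained check of that logic (the two per-call dicts are invented so that their sums match the totals in the notebook output above):

    from typing import List, Optional

    def combine_token_usage(llm_outputs: List[Optional[dict]]) -> dict:
        """Key-wise sum of per-call token_usage dicts, mirroring the hunk above."""
        overall: dict = {}
        for output in llm_outputs:
            if output is None:
                raise ValueError("Should always be something for OpenAI.")
            for k, v in output["token_usage"].items():
                overall[k] = overall.get(k, 0) + v
        return {"token_usage": overall}

    calls = [
        {"token_usage": {"prompt_tokens": 36, "completion_tokens": 8, "total_tokens": 44}},
        {"token_usage": {"prompt_tokens": 35, "completion_tokens": 10, "total_tokens": 45}},
    ]
    assert combine_token_usage(calls) == {
        "token_usage": {"prompt_tokens": 71, "completion_tokens": 18, "total_tokens": 89}
    }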
