diff --git a/docs/modules/chat/getting_started.ipynb b/docs/modules/chat/getting_started.ipynb
index 98662a90..113d652e 100644
--- a/docs/modules/chat/getting_started.ipynb
+++ b/docs/modules/chat/getting_started.ipynb
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "522686de",
    "metadata": {
     "tags": []
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "62e0dbc3",
    "metadata": {
     "tags": []
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "76a6e7b0-e927-4bfb-a414-1332a4149106",
    "metadata": {
     "tags": []
@@ -68,7 +68,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -87,7 +87,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
    "metadata": {
     "tags": []
@@ -99,7 +99,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -122,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "2b21fc52-74b6-4950-ab78-45d12c68fb4d",
    "metadata": {
     "tags": []
@@ -131,10 +131,10 @@
     {
      "data": {
       "text/plain": [
-       "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output=None)"
+       "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output={'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}})"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -150,7 +150,39 @@
     "        HumanMessage(content=\"Translate this sentence from English to French. I love artificial intelligence.\")\n",
     "    ],\n",
     "]\n",
-    "chat.generate(batch_messages)"
+    "result = chat.generate(batch_messages)\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2960f50f",
+   "metadata": {},
+   "source": [
+    "You can recover things like token usage from this LLMResult"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "a6186bee",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'token_usage': {'prompt_tokens': 71,\n",
+       "  'completion_tokens': 18,\n",
+       "  'total_tokens': 89}}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "result.llm_output"
    ]
   },
   {
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index ffb7c03c..1bf9d4fa 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -43,19 +43,26 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         """
         return callback_manager or get_callback_manager()
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        return {}
+
     def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [self._generate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [await self._agenerate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 45a495af..62736b53 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -97,7 +97,8 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
         message = _convert_dict_to_message(res["message"])
         gen = ChatGeneration(message=message)
         generations.append(gen)
-    return ChatResult(generations=generations)
+    llm_output = {"token_usage": response["usage"]}
+    return ChatResult(generations=generations, llm_output=llm_output)
 
 
 class ChatOpenAI(BaseChatModel, BaseModel):
@@ -221,6 +222,19 @@ class ChatOpenAI(BaseChatModel, BaseModel):
 
         return _completion_with_retry(**kwargs)
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                raise ValueError("Should always be something for OpenAI.")
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage}
+
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
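
For context, a minimal usage sketch of the behavior this change enables, mirroring the notebook above. The message wording, temperature=0, and the OPENAI_API_KEY environment variable are illustrative assumptions; only the fact that generate() now returns an LLMResult whose llm_output carries summed token_usage comes from this diff.

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

# Assumes OPENAI_API_KEY is set in the environment; temperature=0 is illustrative.
chat = ChatOpenAI(temperature=0)

batch_messages = [
    [
        SystemMessage(content="You are a helpful assistant that translates English to French."),
        HumanMessage(content="Translate this sentence from English to French. I love programming."),
    ],
    [
        SystemMessage(content="You are a helpful assistant that translates English to French."),
        HumanMessage(content="Translate this sentence from English to French. I love artificial intelligence."),
    ],
]

result = chat.generate(batch_messages)

# generate() combines the per-call llm_output dicts via ChatOpenAI._combine_llm_outputs,
# so token counts are summed across the two requests in the batch.
print(result.llm_output["token_usage"])
# e.g. {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}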