From 696886f3979009af77c3254c44aca35fa0755c62 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 4 Jul 2023 10:19:08 +0100 Subject: [PATCH] Use serialized format for messages in tracer (#6827) --- langchain/callbacks/tracers/base.py | 13 +++++++++++-- langchain/callbacks/tracers/langchain.py | 11 ++++------- tests/integration_tests/llms/test_openai.py | 4 ++-- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 3322fbffc7..040aef5c08 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -4,12 +4,14 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.tracers.schemas import Run, RunTypeEnum -from langchain.schema import Document, LLMResult +from langchain.load.dump import dumpd +from langchain.schema.document import Document +from langchain.schema.output import ChatGeneration, LLMResult logger = logging.getLogger(__name__) @@ -143,6 +145,13 @@ class BaseTracer(BaseCallbackHandler, ABC): if llm_run is None or llm_run.run_type != RunTypeEnum.llm: raise TracerException("No LLM Run found to be traced") llm_run.outputs = response.dict() + for i, generations in enumerate(response.generations): + for j, generation in enumerate(generations): + output_generation = llm_run.outputs["generations"][i][j] + if "message" in output_generation: + output_generation["message"] = dumpd( + cast(ChatGeneration, generation).message + ) llm_run.end_time = datetime.utcnow() llm_run.events.append({"name": "end", "time": llm_run.end_time}) self._end_trace(llm_run) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index f42f736381..378dd62d0c 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -11,13 +11,10 @@ from uuid import UUID from langchainplus_sdk import LangChainPlusClient from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.schemas import ( - Run, - RunTypeEnum, - TracerSession, -) +from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession from langchain.env import get_runtime_environment -from langchain.schema.messages import BaseMessage, messages_to_dict +from langchain.load.dump import dumpd +from langchain.schema.messages import BaseMessage logger = logging.getLogger(__name__) _LOGGED = set() @@ -83,7 +80,7 @@ class LangChainTracer(BaseTracer): id=run_id, parent_run_id=parent_run_id, serialized=serialized, - inputs={"messages": [messages_to_dict(batch) for batch in messages]}, + inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, extra=kwargs, events=[{"name": "start", "time": start_time}], start_time=start_time, diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index f1a146da44..5281b7563a 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -15,8 +15,8 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def test_openai_call() -> None: """Test valid call to openai.""" - llm = OpenAI(max_tokens=10) - output = llm("Say foo:") + llm = OpenAI(max_tokens=10, n=3) + output = llm("Say something nice:") assert isinstance(output, str)