diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py index 49dc96b381..17ce470053 100644 --- a/libs/core/langchain_core/outputs/chat_generation.py +++ b/libs/core/langchain_core/outputs/chat_generation.py @@ -23,7 +23,24 @@ class ChatGeneration(Generation): def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Set the text attribute to be the contents of the message.""" try: - values["text"] = values["message"].content + text = "" + if isinstance(values["message"].content, str): + text = values["message"].content + # HACK: Assumes text in content blocks in OpenAI format. + # Uses first text block. + elif isinstance(values["message"].content, list): + for block in values["message"].content: + if isinstance(block, str): + text = block + break + elif isinstance(block, dict) and "text" in block: + text = block["text"] + break + else: + pass + else: + pass + values["text"] = text except (KeyError, AttributeError) as e: raise ValueError("Error while initializing ChatGeneration") from e return values diff --git a/libs/core/tests/unit_tests/outputs/test_chat_generation.py b/libs/core/tests/unit_tests/outputs/test_chat_generation.py new file mode 100644 index 0000000000..c409a76f0d --- /dev/null +++ b/libs/core/tests/unit_tests/outputs/test_chat_generation.py @@ -0,0 +1,32 @@ +from typing import Union + +import pytest + +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration + + +@pytest.mark.parametrize( + "content", + [ + "foo", + ["foo"], + [{"text": "foo", "type": "text"}], + [ + {"tool_use": {}, "type": "tool_use"}, + {"text": "foo", "type": "text"}, + "bar", + ], + ], +) +def test_msg_with_text(content: Union[str, list]) -> None: + expected = "foo" + actual = ChatGeneration(message=AIMessage(content=content)).text + assert actual == expected + + +@pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]]) +def test_msg_no_text(content: Union[str, list]) -> None: + expected = "" + actual = ChatGeneration(message=AIMessage(content=content)).text + assert actual == expected