Add type to Generation and sub-classes, handle root validator (#12220)

* Add a type literal for the generation and sub-classes for serialization purposes.
* Fix the root validator of ChatGeneration to raise ValueError instead of KeyError or AttributeError if initialized improperly.
* This change is done for langserve to make sure that llm related callbacks can be serialized/deserialized properly.
This commit is contained in:
Eugene Yurtsev 2023-10-24 16:21:00 -04:00 committed by GitHub
parent 81052ee18e
commit 583dc49477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 13 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID
from langchain.load.serializable import Serializable
@ -19,6 +19,8 @@ class Generation(Serializable):
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute
@classmethod
@ -54,11 +56,17 @@ class ChatGeneration(Generation):
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content
try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
@ -71,6 +79,9 @@ class ChatGenerationChunk(ChatGeneration):
"""
message: BaseMessageChunk
# Override type to be ChatGenerationChunk, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):

File diff suppressed because one or more lines are too long

View File

@ -1719,7 +1719,9 @@ async def test_prompt_with_llm(
"op": "add",
"path": "/logs/FakeListLLM/final_output",
"value": {
"generations": [[{"generation_info": None, "text": "foo"}]],
"generations": [
[{"generation_info": None, "text": "foo", "type": "Generation"}]
],
"llm_output": None,
"run": None,
},

View File

@ -1,12 +1,19 @@
"""Test formatting functionality."""
import unittest
from typing import Union
import pytest
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.pydantic_v1 import BaseModel
from langchain.schema import AgentAction, AgentFinish, Document
from langchain.pydantic_v1 import BaseModel, ValidationError
from langchain.schema import (
AgentAction,
AgentFinish,
ChatGeneration,
Document,
Generation,
)
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import (
AIMessage,
@ -23,6 +30,7 @@ from langchain.schema.messages import (
messages_from_dict,
messages_to_dict,
)
from langchain.schema.output import ChatGenerationChunk
class TestGetBufferString(unittest.TestCase):
@ -108,6 +116,9 @@ def test_serialization_of_wellknown_objects() -> None:
AgentFinish,
AgentAction,
AgentActionMessageLog,
ChatGeneration,
Generation,
ChatGenerationChunk,
]
lc_objects = [
@ -144,6 +155,16 @@ def test_serialization_of_wellknown_objects() -> None:
log="",
message_log=[HumanMessage(content="human")],
),
Generation(
text="hello",
generation_info={"info": "info"},
),
ChatGeneration(
message=HumanMessage(content="human"),
),
ChatGenerationChunk(
message=HumanMessageChunk(content="cat"),
),
]
for lc_object in lc_objects:
@ -151,3 +172,7 @@ def test_serialization_of_wellknown_objects() -> None:
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = WellKnownLCObject.parse_obj(d)
assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises(ValidationError):
# Make sure that a ValidationError specifically is raised
WellKnownLCObject.parse_obj({})