mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
Add type to Generation and sub-classes, handle root validator (#12220)
* Add a type literal for the generation and sub-classes for serialization purposes. * Fix the root validator of ChatGeneration to return ValueError instead of KeyError or Attribute error if intialized improperly. * This change is done for langserve to make sure that llm related callbacks can be serialized/deserialized properly.
This commit is contained in:
parent
81052ee18e
commit
583dc49477
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
@ -19,6 +19,8 @@ class Generation(Serializable):
|
||||
"""Raw response from the provider. May include things like the
|
||||
reason for finishing or token log probabilities.
|
||||
"""
|
||||
type: Literal["Generation"] = "Generation"
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
# TODO: add log probs as separate attribute
|
||||
|
||||
@classmethod
|
||||
@ -54,11 +56,17 @@ class ChatGeneration(Generation):
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@root_validator
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
values["text"] = values["message"].content
|
||||
try:
|
||||
values["text"] = values["message"].content
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
|
||||
|
||||
@ -71,6 +79,9 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
"""
|
||||
|
||||
message: BaseMessageChunk
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
|
||||
if isinstance(other, ChatGenerationChunk):
|
||||
|
File diff suppressed because one or more lines are too long
@ -1719,7 +1719,9 @@ async def test_prompt_with_llm(
|
||||
"op": "add",
|
||||
"path": "/logs/FakeListLLM/final_output",
|
||||
"value": {
|
||||
"generations": [[{"generation_info": None, "text": "foo"}]],
|
||||
"generations": [
|
||||
[{"generation_info": None, "text": "foo", "type": "Generation"}]
|
||||
],
|
||||
"llm_output": None,
|
||||
"run": None,
|
||||
},
|
||||
|
@ -1,12 +1,19 @@
|
||||
"""Test formatting functionality."""
|
||||
|
||||
import unittest
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
from langchain.prompts.chat import ChatPromptValueConcrete
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import AgentAction, AgentFinish, Document
|
||||
from langchain.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
ChatGeneration,
|
||||
Document,
|
||||
Generation,
|
||||
)
|
||||
from langchain.schema.agent import AgentActionMessageLog
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
@ -23,6 +30,7 @@ from langchain.schema.messages import (
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
|
||||
|
||||
class TestGetBufferString(unittest.TestCase):
|
||||
@ -108,6 +116,9 @@ def test_serialization_of_wellknown_objects() -> None:
|
||||
AgentFinish,
|
||||
AgentAction,
|
||||
AgentActionMessageLog,
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
ChatGenerationChunk,
|
||||
]
|
||||
|
||||
lc_objects = [
|
||||
@ -144,6 +155,16 @@ def test_serialization_of_wellknown_objects() -> None:
|
||||
log="",
|
||||
message_log=[HumanMessage(content="human")],
|
||||
),
|
||||
Generation(
|
||||
text="hello",
|
||||
generation_info={"info": "info"},
|
||||
),
|
||||
ChatGeneration(
|
||||
message=HumanMessage(content="human"),
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=HumanMessageChunk(content="cat"),
|
||||
),
|
||||
]
|
||||
|
||||
for lc_object in lc_objects:
|
||||
@ -151,3 +172,7 @@ def test_serialization_of_wellknown_objects() -> None:
|
||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||
obj1 = WellKnownLCObject.parse_obj(d)
|
||||
assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
# Make sure that specifically validation error is raised
|
||||
WellKnownLCObject.parse_obj({})
|
||||
|
Loading…
Reference in New Issue
Block a user