Update Serializable to use classmethods (#10956)

pull/11025/head
Eugene Yurtsev 11 months ago committed by GitHub
parent b7290f01d8
commit 09486ed188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,8 +42,8 @@ class LLMChain(Chain):
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate

@ -99,8 +99,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
"""Return type of chat model."""
return "anthropic-chat"
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import os
import sys
from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Dict, Optional, Set
import requests
@ -50,7 +50,7 @@ class ChatAnyscale(ChatOpenAI):
return "anyscale-chat"
@property
def lc_secrets(self) -> dict[str, str]:
def lc_secrets(self) -> Dict[str, str]:
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
anyscale_api_key: Optional[str] = None

@ -290,7 +290,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
return {**params, **kwargs}
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.lc_serializable:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)

@ -164,8 +164,9 @@ class JinaChat(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"jinachat_api_key": "JINACHAT_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
client: Any #: :meta private:

@ -55,8 +55,9 @@ class ChatKonko(ChatOpenAI):
def lc_secrets(self) -> Dict[str, str]:
return {"konko_api_key": "KONKO_API_KEY", "openai_api_key": "OPENAI_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
client: Any = None #: :meta private:

@ -47,8 +47,9 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
"""Return type of chat model."""
return "ollama-chat"
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
def _format_message_as_text(self, message: BaseMessage) -> str:

@ -140,8 +140,9 @@ class ChatOpenAI(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
client: Any = None #: :meta private:

@ -51,8 +51,8 @@ class BaseFireworks(BaseLLM):
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def __new__(cls, **data: Any) -> Any:

@ -138,8 +138,8 @@ class BaseOpenAI(BaseLLM):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
client: Any = None #: :meta private:

@ -65,8 +65,8 @@ class Replicate(LLM):
def lc_secrets(self) -> Dict[str, str]:
return {"replicate_api_token": "REPLICATE_API_TOKEN"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@root_validator(pre=True)

@ -102,8 +102,8 @@ class Tongyi(LLM):
def lc_secrets(self) -> Dict[str, str]:
return {"dashscope_api_key": "DASHSCOPE_API_KEY"}
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
client: Any #: :meta private:

@ -34,20 +34,19 @@ class SerializedNotImplemented(BaseSerialized):
class Serializable(BaseModel, ABC):
"""Serializable base class."""
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Is this class serializable?"""
return False
@property
def lc_namespace(self) -> List[str]:
"""
Return the namespace of the langchain object.
eg. ["langchain", "llms", "openai"]
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the
namespace is ["langchain", "llms", "openai"]
"""
return self.__class__.__module__.split(".")
return cls.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
@ -76,7 +75,7 @@ class Serializable(BaseModel, ABC):
self._lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.lc_serializable:
if not self.is_lc_serializable():
return self.to_json_not_implemented()
secrets = dict()
@ -93,6 +92,20 @@ class Serializable(BaseModel, ABC):
if cls is Serializable:
break
if cls:
deprecated_attributes = [
"lc_namespace",
"lc_serializable",
]
for attr in deprecated_attributes:
if hasattr(cls, attr):
raise ValueError(
f"Class {self.__class__} has a deprecated "
f"attribute {attr}. Please use the corresponding "
f"classmethod instead."
)
# Get a reference to self bound to each class in the MRO
this = cast(Serializable, self if cls is None else super(cls, self))
@ -109,7 +122,7 @@ class Serializable(BaseModel, ABC):
return {
"lc": 1,
"type": "constructor",
"id": [*self.lc_namespace, self.__class__.__name__],
"id": [*self.get_lc_namespace(), self.__class__.__name__],
"kwargs": lc_kwargs
if not secrets
else _replace_secrets(lc_kwargs, secrets),

@ -9,8 +9,8 @@ from langchain.schema import BaseOutputParser
class CombiningOutputParser(BaseOutputParser):
"""Combine multiple output parsers into one."""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
parsers: List[BaseOutputParser]

@ -13,8 +13,8 @@ T = TypeVar("T")
class OutputFixingParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors."""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
parser: BaseOutputParser[T]

@ -22,8 +22,8 @@ class ListOutputParser(BaseOutputParser[List[str]]):
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def get_format_instructions(self) -> str:

@ -9,8 +9,8 @@ from langchain.schema import BaseOutputParser
class RegexParser(BaseOutputParser):
"""Parse the output of an LLM call using a regex."""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
regex: str

@ -39,13 +39,9 @@ from langchain.schema.messages import (
class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@property
def lc_serializable(self) -> bool:
"""Whether this object should be serialized.
Returns:
Whether this object should be serialized.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@abstractmethod

@ -71,13 +71,9 @@ class _FewShotPromptTemplateMixin(BaseModel):
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@property
def lc_serializable(self) -> bool:
"""Return whether the prompt template is lc_serializable.
Returns:
Boolean indicating whether the prompt template is lc_serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
validate_template: bool = True
@ -278,13 +274,9 @@ class FewShotChatMessagePromptTemplate(
chain.invoke({"input": "What's 3+3?"})
"""
@property
def lc_serializable(self) -> bool:
"""Return whether the prompt template is lc_serializable.
Returns:
Boolean indicating whether the prompt template is lc_serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False
input_variables: List[str] = Field(default_factory=list)

@ -27,11 +27,9 @@ class AgentAction(Serializable):
):
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@ -62,9 +60,7 @@ class AgentFinish(Serializable):
def __init__(self, return_values: dict, log: str, **kwargs: Any):
super().__init__(return_values=return_values, log=log, **kwargs)
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True

@ -17,9 +17,9 @@ class Document(Serializable):
documents, etc.).
"""
@property
def lc_serializable(self) -> bool:
"""Return whether or not the class is serializable."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True

@ -74,9 +74,9 @@ class BaseMessage(Serializable):
def type(self) -> str:
"""Type of the Message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:

@ -21,9 +21,9 @@ class Generation(Serializable):
"""
# TODO: add log probs as separate attribute
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True

@ -304,9 +304,9 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@property
def lc_serializable(self) -> bool:
"""Whether the class LangChain serializable."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property

@ -14,11 +14,9 @@ class PromptValue(Serializable, ABC):
ChatModel inputs.
"""
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@abstractmethod

@ -26,8 +26,9 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
default_factory=dict
)
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
class Config:

@ -834,15 +834,15 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@property
def lc_namespace(self) -> List[str]:
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return self.__class__.__module__.split(".")[:-1]
return cls.__module__.split(".")[:-1]
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
@ -946,13 +946,13 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
@ -1184,13 +1184,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def steps(self) -> List[Runnable[Any, Any]]:
return [self.first] + self.middle + [self.last]
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
class Config:
arbitrary_types_allowed = True
@ -1674,13 +1674,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
) -> None:
super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
class Config:
arbitrary_types_allowed = True
@ -2061,13 +2061,13 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
@ -2117,13 +2117,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))

@ -20,13 +20,13 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(identity, input, config)

@ -66,13 +66,14 @@ class RouterRunnable(
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def __or__(
self,

@ -19,8 +19,8 @@ class Person(Serializable):
you_can_see_me: str = "hello"
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property

@ -1694,8 +1694,9 @@ async def test_llm_with_fallbacks(
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a comma-separated list."""
@property
def lc_serializable(self) -> bool:
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
def get_format_instructions(self) -> str:

Loading…
Cancel
Save