You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/schema.py

393 lines
11 KiB
Python

"""Common schema objects."""
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional

from pydantic import BaseModel, Extra, Field, root_validator
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a list of messages as a newline-separated transcript.

    Each message becomes one ``"<role>: <content>"`` line.

    Args:
        messages: messages to render, in order.
        human_prefix: role label used for ``HumanMessage`` instances.
        ai_prefix: role label used for ``AIMessage`` instances.

    Returns:
        The rendered lines joined with ``"\\n"`` (empty string for no messages).

    Raises:
        ValueError: if a message is not one of the four known message classes.
    """

    def _speaker(message: BaseMessage) -> str:
        # The checks run in the same order as the message classes are tested
        # everywhere else in this module: human, ai, system, chat.
        if isinstance(message, HumanMessage):
            return human_prefix
        if isinstance(message, AIMessage):
            return ai_prefix
        if isinstance(message, SystemMessage):
            return "System"
        if isinstance(message, ChatMessage):
            return message.role
        raise ValueError(f"Got unsupported message type: {message}")

    return "\n".join(f"{_speaker(msg)}: {msg.content}" for msg in messages)
class AgentAction(NamedTuple):
    """The action an agent has decided to take.

    Fields, in positional order: ``tool``, ``tool_input``, ``log``.
    """

    # Tool selected for this action.
    tool: str
    # Input passed along with the selected tool.
    tool_input: str
    # Log text recorded for this action.
    log: str
class AgentFinish(NamedTuple):
    """The value an agent returns when it is done.

    Fields, in positional order: ``return_values``, ``log``.
    """

    # Values returned by the agent.
    return_values: dict
    # Log text recorded alongside the return values.
    log: str
class Generation(BaseModel):
    """Result of one generation."""

    text: str
    """The generated text."""

    generation_info: Optional[Dict[str, Any]] = None
    """Raw generation info response from the provider.

    May include things like the reason for finishing (e.g. in OpenAI).
    """
    # TODO: add log probs
class BaseMessage(BaseModel):
    """Base class for all message objects.

    Concrete subclasses must provide the ``type`` property.
    """

    content: str
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Tag identifying the message kind; used when serializing."""
class HumanMessage(BaseMessage):
    """Message spoken by the human."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"human"``."""
        return "human"
class AIMessage(BaseMessage):
    """Message spoken by the AI."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"ai"``."""
        return "ai"
class SystemMessage(BaseMessage):
    """System message."""

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"system"``."""
        return "system"
class ChatMessage(BaseMessage):
    """Message whose speaker is an arbitrary, caller-supplied role."""

    role: str

    @property
    def type(self) -> str:
        """Serialization tag for this message kind: ``"chat"``."""
        return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
    """Serialize one message into a ``{"type": ..., "data": ...}`` dict.

    ``"type"`` holds the message's ``type`` tag and ``"data"`` holds the
    result of its ``dict()`` method.
    """
    type_tag = message.type
    payload = message.dict()
    return {"type": type_tag, "data": payload}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
    """Serialize every message in ``messages``, preserving order.

    Each element is converted with ``_message_to_dict``.
    """
    return list(map(_message_to_dict, messages))
def _message_from_dict(message: dict) -> BaseMessage:
    """Rebuild a message object from its ``{"type": ..., "data": ...}`` dict.

    Raises:
        ValueError: if ``message["type"]`` is not one of
            ``"human"``, ``"ai"``, ``"system"`` or ``"chat"``.
    """
    _type = message["type"]
    # Ordered (tag, class) pairs; compared with == so that any tag value,
    # hashable or not, falls through to the ValueError below.
    known_kinds = (
        ("human", HumanMessage),
        ("ai", AIMessage),
        ("system", SystemMessage),
        ("chat", ChatMessage),
    )
    for tag, message_cls in known_kinds:
        if _type == tag:
            return message_cls(**message["data"])
    raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Rebuild message objects from serialized dicts, preserving order.

    Each element is converted with ``_message_from_dict``.
    """
    return list(map(_message_from_dict, messages))
class ChatGeneration(Generation):
    """Output of a single chat generation.

    ``text`` is not meant to be supplied by callers: it is overwritten with
    ``message.content`` during validation.
    """

    text = ""
    message: BaseMessage

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Copy the message content into ``text``.

        Args:
            values: the field values validated so far.

        Returns:
            ``values`` with ``"text"`` set to the message content when a
            valid message is present, otherwise ``values`` unchanged.
        """
        # This validator also runs when `message` is missing or failed its
        # own validation, in which case the key is absent from `values`.
        # Indexing it directly would raise a bare KeyError and hide the real
        # validation error for `message`.
        message = values.get("message")
        if message is not None:
            values["text"] = message.content
        return values
class ChatResult(BaseModel):
    """Container for everything relevant to a chat result."""

    generations: List[ChatGeneration]
    """The generated outputs."""

    llm_output: Optional[dict] = None
    """Arbitrary output specific to the LLM provider."""
class LLMResult(BaseModel):
    """Container for everything relevant to an LLM result."""

    generations: List[List[Generation]]
    """The generated outputs.

    This is a list of lists because each input could have multiple generations.
    """

    llm_output: Optional[dict] = None
    """Arbitrary output specific to the LLM provider."""
class PromptValue(BaseModel, ABC):
    """Abstract prompt that can be rendered as a string or as messages."""

    @abstractmethod
    def to_string(self) -> str:
        """Render the prompt as a single string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Render the prompt as a list of messages."""
@lru_cache(maxsize=None)
def _get_gpt2_tokenizer() -> Any:
    """Load the pretrained ``"gpt2"`` tokenizer once and reuse it afterwards.

    Loading the tokenizer is costly, so the instance is cached for the
    lifetime of the process. A failed import is not cached, so the import
    is retried on the next call.

    Raises:
        ValueError: if the ``transformers`` package cannot be imported.
    """
    try:
        from transformers import GPT2TokenizerFast
    except ImportError as err:
        raise ValueError(
            "Could not import transformers python package. "
            "This is needed in order to calculate get_num_tokens. "
            "Please install it with `pip install transformers`."
        ) from err
    return GPT2TokenizerFast.from_pretrained("gpt2")


class BaseLanguageModel(BaseModel, ABC):
    """Abstract interface for language models driven by prompt values."""

    @abstractmethod
    def generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    @abstractmethod
    async def agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        The count comes from the GPT-2 tokenizer of the ``transformers``
        package.

        Args:
            text: the text to tokenize.

        Returns:
            Number of tokens the tokenizer produces for ``text``.

        Raises:
            ValueError: if the ``transformers`` package is not installed.
        """
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        return len(_get_gpt2_tokenizer().tokenize(text))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the total number of tokens across the given messages.

        Each message is rendered on its own with ``get_buffer_string`` and
        counted with ``get_num_tokens``; the per-message counts are summed.
        """
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
class BaseMemory(BaseModel, ABC):
    """Abstract interface that memory used in chains must implement."""

    class Config:
        """Pydantic configuration for this model."""

        arbitrary_types_allowed = True
        extra = Extra.forbid

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """Input keys that this memory class loads dynamically."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain.

        If None, return all memories.
        """

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Store the context of this model run in memory."""

    @abstractmethod
    def clear(self) -> None:
        """Erase the memory contents."""
class BaseChatMessageHistory(ABC):
    """Base interface for chat message history.

    See `ChatMessageHistory` for default implementation.

    Example:
        .. code-block:: python

            class FileChatMessageHistory(BaseChatMessageHistory):
                storage_path: str
                session_id: str

                @property
                def _path(self):
                    return os.path.join(self.storage_path, self.session_id)

                @property
                def messages(self):
                    with open(self._path, "r", encoding="utf-8") as f:
                        messages = json.loads(f.read())
                    return messages_from_dict(messages)

                def add_user_message(self, message: str):
                    messages = self.messages
                    messages.append(HumanMessage(content=message))
                    with open(self._path, "w", encoding="utf-8") as f:
                        json.dump(messages_to_dict(messages), f)

                def add_ai_message(self, message: str):
                    messages = self.messages
                    messages.append(AIMessage(content=message))
                    with open(self._path, "w", encoding="utf-8") as f:
                        json.dump(messages_to_dict(messages), f)

                def clear(self):
                    with open(self._path, "w", encoding="utf-8") as f:
                        f.write("[]")
    """

    # Annotation only: the messages held by the store. Implementations
    # provide it as an attribute or a property.
    messages: List[BaseMessage]

    @abstractmethod
    def add_user_message(self, message: str) -> None:
        """Add a user message to the store."""

    @abstractmethod
    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store."""

    @abstractmethod
    def clear(self) -> None:
        """Remove all messages from the store."""
class Document(BaseModel):
    """Interface for interacting with a document.

    Holds the document text in ``page_content`` and a ``metadata`` dict
    that defaults to a fresh empty dict per instance.
    """

    page_content: str
    metadata: dict = Field(default_factory=dict)
class BaseRetriever(ABC):
    """Abstract interface for looking up documents relevant to a query.

    Subclasses implement both the synchronous and the asynchronous lookup.
    """

    @abstractmethod
    def get_relevant_documents(self, query: str) -> List[Document]:
        """Look up the documents relevant to a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """

    @abstractmethod
    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronously look up the documents relevant to a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
# For backwards compatibility: `Memory` is an alias bound to the same class
# object as `BaseMemory`.
Memory = BaseMemory
class BaseOutputParser(BaseModel, ABC):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    @abstractmethod
    def parse(self, text: str) -> Any:
        """Parse the output of an LLM call.

        A method which takes in a string (assumed output of language model)
        and parses it into some structure.

        Args:
            text: output of language model

        Returns:
            structured output
        """

    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Optional method to parse the output of an LLM call with a prompt.

        The prompt is largely provided in the event the OutputParser wants
        to retry or fix the output in some way, and needs information from
        the prompt to do so. This default implementation ignores ``prompt``
        and delegates to ``parse``.

        Args:
            completion: output of language model
            prompt: prompt value

        Returns:
            structured output
        """
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the type key.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser.

        Args:
            **kwargs: forwarded unchanged to the parent class's ``dict``.

        Returns:
            The parent's dictionary with an added ``"_type"`` entry.

        Raises:
            NotImplementedError: if the subclass does not override ``_type``.
        """
        # Forward kwargs so caller-supplied options reach the parent call
        # instead of being silently discarded.
        output_parser_dict = super().dict(**kwargs)
        output_parser_dict["_type"] = self._type
        return output_parser_dict
class OutputParserException(Exception):
    """Error that output parsers should raise to signal a parsing failure.

    Having a dedicated type separates parsing errors from other code or
    execution errors that may also arise inside an output parser.
    OutputParserExceptions can be caught and handled in ways that fix the
    parsing error, while other errors will be raised.
    """