From 276940fd9babf8aec570dd869cc84fbca1c766bf Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Thu, 16 Mar 2023 23:20:08 -0700
Subject: [PATCH] Harrison/official method (#1728)

Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
---
 langchain/chat_models/openai.py    | 38 ++++++++++++++++++++++++++++++
 langchain/memory/buffer.py         |  3 ++-
 langchain/memory/buffer_window.py  |  3 +--
 langchain/memory/entity.py         |  4 ++--
 langchain/memory/kg.py             |  4 ++--
 langchain/memory/summary.py        |  8 +++++--
 langchain/memory/summary_buffer.py | 11 +++------
 langchain/memory/utils.py          | 28 +---------------------
 langchain/schema.py                | 24 +++++++++++++++++++
 9 files changed, 79 insertions(+), 44 deletions(-)

diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 5dd1760c..f6e0f381 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -317,3 +317,41 @@ class ChatOpenAI(BaseChatModel, BaseModel):
 
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)
+
+    def get_num_tokens_from_messages(
+        self, messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301"
+    ) -> int:
+        """Calculate num tokens for gpt-3.5-turbo with the tiktoken package."""
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        # Returns the number of tokens used by a list of messages.
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            encoding = tiktoken.get_encoding("cl100k_base")
+        if model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
+            num_tokens = 0
+            messages_dict = [_convert_message_to_dict(m) for m in messages]
+            for message in messages_dict:
+                # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                num_tokens += 4
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens += -1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}. "
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
diff --git a/langchain/memory/buffer.py b/langchain/memory/buffer.py
index 8ca19e9b..0e197f84 100644
--- a/langchain/memory/buffer.py
+++ b/langchain/memory/buffer.py
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
+from langchain.schema import get_buffer_string
 
 
 class ConversationBufferMemory(BaseChatMemory, BaseModel):
diff --git a/langchain/memory/buffer_window.py b/langchain/memory/buffer_window.py
index 9d94b9d1..d76faadd 100644
--- a/langchain/memory/buffer_window.py
+++ b/langchain/memory/buffer_window.py
@@ -3,8 +3,7 @@ from typing import Any, Dict, List
 from pydantic import BaseModel
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage, get_buffer_string
 
 
 class ConversationBufferWindowMemory(BaseChatMemory, BaseModel):
diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py
index 06fe1107..73f0bc15 100644
--- a/langchain/memory/entity.py
+++ b/langchain/memory/entity.py
@@ -8,9 +8,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     ENTITY_SUMMARIZATION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
 
 
 class ConversationEntityMemory(BaseChatMemory, BaseModel):
diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py
index a6a23402..7d23f60b 100644
--- a/langchain/memory/kg.py
+++ b/langchain/memory/kg.py
@@ -10,9 +10,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
 
 
 class ConversationKGMemory(BaseChatMemory, BaseModel):
diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py
index 136c4f73..082e3a43 100644
--- a/langchain/memory/summary.py
+++ b/langchain/memory/summary.py
@@ -5,9 +5,13 @@ from pydantic import BaseModel, root_validator
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.memory.utils import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import (
+    BaseLanguageModel,
+    BaseMessage,
+    SystemMessage,
+    get_buffer_string,
+)
 
 
 class SummarizerMixin(BaseModel):
diff --git a/langchain/memory/summary_buffer.py b/langchain/memory/summary_buffer.py
index 44049b9d..9f0f5c6b 100644
--- a/langchain/memory/summary_buffer.py
+++ b/langchain/memory/summary_buffer.py
@@ -4,8 +4,7 @@ from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
 
 
 class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
@@ -55,21 +54,17 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
         )
         return values
 
-    def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
-        """Get list of number of tokens in each string in the input array."""
-        return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
-
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
         super().save_context(inputs, outputs)
         # Prune buffer if it exceeds max token limit
         buffer = self.chat_memory.messages
-        curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit:
                 pruned_memory.append(buffer.pop(0))
-                curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
             self.moving_summary_buffer = self.predict_new_summary(
                 pruned_memory, self.moving_summary_buffer
             )
diff --git a/langchain/memory/utils.py b/langchain/memory/utils.py
index 5b6f0044..ecff2624 100644
--- a/langchain/memory/utils.py
+++ b/langchain/memory/utils.py
@@ -1,32 +1,6 @@
 from typing import Any, Dict, List
 
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-def get_buffer_string(
-    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
-) -> str:
-    """Get buffer string of messages."""
-    string_messages = []
-    for m in messages:
-        if isinstance(m, HumanMessage):
-            role = human_prefix
-        elif isinstance(m, AIMessage):
-            role = ai_prefix
-        elif isinstance(m, SystemMessage):
-            role = "System"
-        elif isinstance(m, ChatMessage):
-            role = m.role
-        else:
-            raise ValueError(f"Got unsupported message type: {m}")
-        string_messages.append(f"{role}: {m.content}")
-    return "\n".join(string_messages)
+from langchain.schema import get_buffer_string  # noqa: F401
 
 
 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
diff --git a/langchain/schema.py b/langchain/schema.py
index 50620f18..0a4e5ef5 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -7,6 +7,26 @@ from typing import Any, Dict, List, NamedTuple, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
 
+def get_buffer_string(
+    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Get buffer string of messages."""
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        string_messages.append(f"{role}: {m.content}")
+    return "\n".join(string_messages)
+
+
 class AgentAction(NamedTuple):
     """Agent's action to take."""
@@ -185,6 +205,10 @@ class BaseLanguageModel(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages."""
+        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
+
 
 class BaseMemory(BaseModel, ABC):
     """Base interface for memory in chains."""
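
A note on the new token accounting: the get_num_tokens_from_messages method added to ChatOpenAI above follows the per-message scheme OpenAI documents for gpt-3.5-turbo-0301 (4 tokens of chatml framing per message, minus 1 when a name replaces the role, plus 2 to prime the reply). Below is a minimal standalone sketch of that same scheme; it assumes only `pip install tiktoken`, and the sample messages are illustrative rather than taken from the patch.

from typing import Dict, List

import tiktoken


def num_tokens_from_messages(
    messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0301"
) -> int:
    """Count tokens the way gpt-3.5-turbo-0301 consumes chat messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # cl100k_base is the encoding the gpt-3.5-turbo family uses
        encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":  # a name replaces the 1-token role
                num_tokens -= 1
    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many tokens is this conversation?"},
]
print(num_tokens_from_messages(messages))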
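
The behavioral change in summary_buffer.py is that pruning now measures the buffer with a single per-message count instead of summing per-string counts. A rough sketch of that pop-from-the-front pruning loop follows, using (role, content) tuples and a naive whitespace token count as stand-ins (both are assumptions for illustration; the real class calls self.llm.get_num_tokens_from_messages and then summarizes the pruned messages into moving_summary_buffer).

from typing import List, Tuple

Message = Tuple[str, str]  # (role, content) stand-in for BaseMessage


def buffer_string(messages: List[Message]) -> str:
    # same "Role: content" layout as get_buffer_string
    return "\n".join(f"{role}: {content}" for role, content in messages)


def num_tokens(messages: List[Message]) -> int:
    # naive whitespace count standing in for llm.get_num_tokens_from_messages
    return sum(len(buffer_string([m]).split()) for m in messages)


def prune(buffer: List[Message], max_token_limit: int) -> List[Message]:
    """Pop the oldest messages until the buffer fits; return what was pruned."""
    pruned_memory: List[Message] = []
    while num_tokens(buffer) > max_token_limit:
        pruned_memory.append(buffer.pop(0))
    return pruned_memory


buffer = [("Human", "hello there"), ("AI", "hi"), ("Human", "tell me a long story")]
print(prune(buffer, max_token_limit=6))  # the two oldest messages are pruned
print(buffer)                            # the long question stays in the buffer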
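
Finally, because get_buffer_string now lives in langchain.schema, it can be imported directly alongside the message classes the diff references there. A small usage sketch (the message contents are made up):

from langchain.schema import AIMessage, HumanMessage, get_buffer_string

history = [
    HumanMessage(content="What is the capital of France?"),
    AIMessage(content="Paris."),
]
# prints: Human: What is the capital of France?\nAI: Paris.
print(get_buffer_string(history))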