Harrison/official method (#1728)

Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
Harrison Chase 2023-03-16 23:20:08 -07:00 committed by GitHub
parent cdff6c8181
commit 276940fd9b
9 changed files with 79 additions and 44 deletions

View File

@@ -317,3 +317,41 @@ class ChatOpenAI(BaseChatModel, BaseModel):
        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_num_tokens_from_messages(
        self, messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301"
    ) -> int:
        """Calculate num tokens for gpt-3.5-turbo with tiktoken package."""
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )
        # Returns the number of tokens used by a list of messages.
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")
        if model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
            num_tokens = 0
            messages_dict = [_convert_message_to_dict(m) for m in messages]
            for message in messages_dict:
                # every message follows <im_start>{role/name}\n{content}<im_end>\n
                num_tokens += 4
                for key, value in message.items():
                    num_tokens += len(encoding.encode(value))
                    if key == "name":  # if there's a name, the role is omitted
                        num_tokens += -1  # role is always required and always 1 token
            num_tokens += 2  # every reply is primed with <im_start>assistant
            return num_tokens
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
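
The arithmetic above mirrors OpenAI's published recipe for gpt-3.5-turbo-0301: 4 ChatML framing tokens per message, plus the tiktoken-encoded content, a 1-token adjustment when a name field stands in for the role, and 2 tokens priming the assistant reply. A minimal usage sketch (assumes tiktoken is installed and OPENAI_API_KEY is set; the messages are illustrative):

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

chat = ChatOpenAI()
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is the capital of France?"),
]
# Counting happens locally with tiktoken; no API request is made.
print(chat.get_num_tokens_from_messages(messages))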

View File

@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.memory.utils import get_prompt_input_key
from langchain.schema import get_buffer_string
class ConversationBufferMemory(BaseChatMemory, BaseModel):

View File

@@ -3,8 +3,7 @@ from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.utils import get_buffer_string
from langchain.schema import BaseMessage
from langchain.schema import BaseMessage, get_buffer_string
class ConversationBufferWindowMemory(BaseChatMemory, BaseModel):

View File

@@ -8,9 +8,9 @@ from langchain.memory.prompt import (
    ENTITY_EXTRACTION_PROMPT,
    ENTITY_SUMMARIZATION_PROMPT,
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
class ConversationEntityMemory(BaseChatMemory, BaseModel):

View File

@@ -10,9 +10,9 @@ from langchain.memory.prompt import (
    ENTITY_EXTRACTION_PROMPT,
    KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, SystemMessage
from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
class ConversationKGMemory(BaseChatMemory, BaseModel):

View File

@@ -5,9 +5,13 @@ from pydantic import BaseModel, root_validator
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
from langchain.schema import (
    BaseLanguageModel,
    BaseMessage,
    SystemMessage,
    get_buffer_string,
)
class SummarizerMixin(BaseModel):

View File

@@ -4,8 +4,7 @@ from pydantic import BaseModel, root_validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
from langchain.memory.utils import get_buffer_string
from langchain.schema import BaseMessage, SystemMessage
from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
@@ -55,21 +54,17 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
        )
        return values

    def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
        """Get list of number of tokens in each string in the input array."""
        return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        super().save_context(inputs, outputs)
        # Prune buffer if it exceeds max token limit
        buffer = self.chat_memory.messages
        curr_buffer_length = sum(self.get_num_tokens_list(buffer))
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit:
                pruned_memory.append(buffer.pop(0))
                curr_buffer_length = sum(self.get_num_tokens_list(buffer))
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
            self.moving_summary_buffer = self.predict_new_summary(
                pruned_memory, self.moving_summary_buffer
            )
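
With this refactor the memory class defers token accounting to the model itself, so chat models get exact ChatML counts while plain LLMs use the buffer-string fallback. A sketch of the pruning behavior (assumes OPENAI_API_KEY is set and that ConversationSummaryBufferMemory is exported from langchain.memory; the 60-token limit and the texts are illustrative, and predict_new_summary makes a real model call):

from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationSummaryBufferMemory

memory = ConversationSummaryBufferMemory(llm=ChatOpenAI(), max_token_limit=60)
memory.save_context({"input": "Hi, I'm comparing token counters."},
                    {"output": "Happy to help. What are you comparing?"})
memory.save_context({"input": "Buffer strings versus ChatML counts."},
                    {"output": "ChatML counts include per-message framing tokens."})
# Turns beyond the 60-token budget are popped off and folded into the summary.
print(memory.moving_summary_buffer)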

View File

@@ -1,32 +1,6 @@
from typing import Any, Dict, List

from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)


def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Get buffer string of messages."""
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix
        elif isinstance(m, AIMessage):
            role = ai_prefix
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        string_messages.append(f"{role}: {m.content}")
    return "\n".join(string_messages)

from langchain.schema import get_buffer_string  # noqa: F401
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
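
Since get_buffer_string now lives in langchain.schema and is only re-exported here, both import paths keep working. A quick compatibility check:

from langchain.memory.utils import get_buffer_string as legacy
from langchain.schema import get_buffer_string as canonical

# The re-export binds both names to the very same function object.
assert legacy is canonical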

View File

@@ -7,6 +7,26 @@ from typing import Any, Dict, List, NamedTuple, Optional
from pydantic import BaseModel, Extra, Field, root_validator


def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Get buffer string of messages."""
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix
        elif isinstance(m, AIMessage):
            role = ai_prefix
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        string_messages.append(f"{role}: {m.content}")
    return "\n".join(string_messages)
class AgentAction(NamedTuple):
    """Agent's action to take."""

@@ -185,6 +205,10 @@ class BaseLanguageModel(BaseModel, ABC):
        # calculate the number of tokens in the tokenized text
        return len(tokenized_text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages."""
        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])


class BaseMemory(BaseModel, ABC):
    """Base interface for memory in chains."""