Harrison/official method (#1728)
Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
commit 276940fd9b (parent cdff6c8181)
```diff
@@ -317,3 +317,41 @@ class ChatOpenAI(BaseChatModel, BaseModel):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)
 
+    def get_num_tokens_from_messages(
+        self, messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301"
+    ) -> int:
+        """Calculate num tokens for gpt-3.5-turbo with tiktoken package.
+
+        Returns the number of tokens used by a list of messages.
+        """
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            encoding = tiktoken.get_encoding("cl100k_base")
+        if model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
+            num_tokens = 0
+            messages_dict = [_convert_message_to_dict(m) for m in messages]
+            for message in messages_dict:
+                # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                num_tokens += 4
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens -= 1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens
+        else:
+            raise NotImplementedError(
+                "get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}. "
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
```
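For reference, a minimal usage sketch of the new override. The message contents, the `model_name` argument, and the import paths here are illustrative assumptions, not part of the diff:

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

chat = ChatOpenAI(model_name="gpt-3.5-turbo")
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is the capital of France?"),
]
# Exact count via tiktoken using the ChatML rules above: 4 framing tokens
# per message, plus 2 tokens priming the assistant's reply.
print(chat.get_num_tokens_from_messages(messages))
```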
```diff
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
+from langchain.schema import get_buffer_string
 
 
 class ConversationBufferMemory(BaseChatMemory, BaseModel):
```
```diff
@@ -3,8 +3,7 @@ from typing import Any, Dict, List
 from pydantic import BaseModel
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage, get_buffer_string
 
 
 class ConversationBufferWindowMemory(BaseChatMemory, BaseModel):
```
```diff
@@ -8,9 +8,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     ENTITY_SUMMARIZATION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
 
 
 class ConversationEntityMemory(BaseChatMemory, BaseModel):
```
```diff
@@ -10,9 +10,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
 
 
 class ConversationKGMemory(BaseChatMemory, BaseModel):
```
```diff
@@ -5,9 +5,13 @@ from pydantic import BaseModel, root_validator
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.memory.utils import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import (
+    BaseLanguageModel,
+    BaseMessage,
+    SystemMessage,
+    get_buffer_string,
+)
 
 
 class SummarizerMixin(BaseModel):
```
```diff
@@ -4,8 +4,7 @@ from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
 
 
 class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
```
```diff
@@ -55,21 +54,17 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
         )
         return values
 
-    def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
-        """Get list of number of tokens in each string in the input array."""
-        return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
-
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
         super().save_context(inputs, outputs)
         # Prune buffer if it exceeds max token limit
         buffer = self.chat_memory.messages
-        curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit:
                 pruned_memory.append(buffer.pop(0))
-                curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
             self.moving_summary_buffer = self.predict_new_summary(
                 pruned_memory, self.moving_summary_buffer
             )
```
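A sketch of how the pruning path behaves after this change. The constructor arguments, the `OpenAI` wrapper, and the import paths are assumptions for illustration:

```python
from langchain.llms import OpenAI
from langchain.memory import ConversationSummaryBufferMemory

memory = ConversationSummaryBufferMemory(llm=OpenAI(), max_token_limit=40)
memory.save_context({"input": "hi"}, {"output": "whats up"})
memory.save_context({"input": "not much, you?"}, {"output": "not much"})
# Buffer length is now measured with llm.get_num_tokens_from_messages;
# once it exceeds max_token_limit, the oldest messages are popped into
# pruned_memory and folded into moving_summary_buffer.
```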
```diff
@@ -1,32 +1,6 @@
 from typing import Any, Dict, List
 
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-def get_buffer_string(
-    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
-) -> str:
-    """Get buffer string of messages."""
-    string_messages = []
-    for m in messages:
-        if isinstance(m, HumanMessage):
-            role = human_prefix
-        elif isinstance(m, AIMessage):
-            role = ai_prefix
-        elif isinstance(m, SystemMessage):
-            role = "System"
-        elif isinstance(m, ChatMessage):
-            role = m.role
-        else:
-            raise ValueError(f"Got unsupported message type: {m}")
-        string_messages.append(f"{role}: {m.content}")
-    return "\n".join(string_messages)
+from langchain.schema import get_buffer_string  # noqa: F401
 
 
 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
```
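The re-export above keeps existing code that imports `get_buffer_string` from `langchain.memory.utils` working, while new code can use the canonical location. A quick illustration (assuming both modules are importable):

```python
# Both import paths resolve to the same function after this change.
from langchain.schema import get_buffer_string  # canonical location
from langchain.memory.utils import get_buffer_string as legacy_alias

assert legacy_alias is get_buffer_string
```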
```diff
@@ -7,6 +7,26 @@ from typing import Any, Dict, List, NamedTuple, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
 
+def get_buffer_string(
+    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Get buffer string of messages."""
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        string_messages.append(f"{role}: {m.content}")
+    return "\n".join(string_messages)
+
+
 class AgentAction(NamedTuple):
     """Agent's action to take."""
 
```
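What the relocated helper returns, sketched with two of the message classes it handles:

```python
from langchain.schema import AIMessage, HumanMessage, get_buffer_string

transcript = get_buffer_string(
    [HumanMessage(content="hi"), AIMessage(content="hello")]
)
# 'Human: hi\nAI: hello'  (prefixes configurable via human_prefix/ai_prefix)
print(transcript)
```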
```diff
@@ -185,6 +205,10 @@ class BaseLanguageModel(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages."""
+        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
+
 
 class BaseMemory(BaseModel, ABC):
     """Base interface for memory in chains."""
```
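This base-class version is the generic fallback: it stringifies each message with `get_buffer_string` and sums per-message token counts, so any `BaseLanguageModel` can be handed a message list (as `ConversationSummaryBufferMemory.save_context` now does), while `ChatOpenAI` overrides it with the exact tiktoken accounting above. A sketch of the fallback path, where the `OpenAI` wrapper and import path are assumptions:

```python
from langchain.llms import OpenAI
from langchain.schema import HumanMessage

llm = OpenAI()  # plain completion model: no override, uses the fallback
# Equivalent to llm.get_num_tokens("Human: hi") for the single message below.
n = llm.get_num_tokens_from_messages([HumanMessage(content="hi")])
```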