Harrison/official method (#1728)

Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>

commit 276940fd9b (parent cdff6c8181)
@@ -317,3 +317,41 @@ class ChatOpenAI(BaseChatModel, BaseModel):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)
 
+
+    def get_num_tokens_from_messages(
+        self, messages: List[BaseMessage], model: str = "gpt-3.5-turbo-0301"
+    ) -> int:
+        """Calculate num tokens for gpt-3.5-turbo with tiktoken package."""
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
+        # Returns the number of tokens used by a list of messages.
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            encoding = tiktoken.get_encoding("cl100k_base")
+        if model == "gpt-3.5-turbo-0301":
+            # note: future models may deviate from this
+            num_tokens = 0
+            messages_dict = [_convert_message_to_dict(m) for m in messages]
+            for message in messages_dict:
+                # every message follows <im_start>{role/name}\n{content}<im_end>\n
+                num_tokens += 4
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":  # if there's a name, the role is omitted
+                        num_tokens += -1  # role is always required and always 1 token
+            num_tokens += 2  # every reply is primed with <im_start>assistant
+            return num_tokens
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}. "
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
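Taken together, the new method implements the "official" counting recipe the commit title refers to: four framing tokens per ChatML message, one token saved when a name field replaces the role, and two tokens priming the assistant's reply. A minimal standalone sketch of the same rule, assuming tiktoken is installed; the example message is made up:

    from typing import Dict, List

    import tiktoken

    def count_chatml_tokens(
        messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0301"
    ) -> int:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")  # fallback encoding
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # <im_start>{role/name}\n{content}<im_end>\n
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # a name replaces the role, saving one token
                    num_tokens -= 1
        return num_tokens + 2  # every reply is primed with <im_start>assistant

    print(count_chatml_tokens([{"role": "user", "content": "Hello, world"}]))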
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
+from langchain.schema import get_buffer_string
 
 
 class ConversationBufferMemory(BaseChatMemory, BaseModel):
@@ -3,8 +3,7 @@ from typing import Any, Dict, List
 from pydantic import BaseModel
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage, get_buffer_string
 
 
 class ConversationBufferWindowMemory(BaseChatMemory, BaseModel):
@@ -8,9 +8,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     ENTITY_SUMMARIZATION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
 
 
 class ConversationEntityMemory(BaseChatMemory, BaseModel):
@@ -10,9 +10,9 @@ from langchain.memory.prompt import (
     ENTITY_EXTRACTION_PROMPT,
     KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
 )
-from langchain.memory.utils import get_buffer_string, get_prompt_input_key
+from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
 
 
 class ConversationKGMemory(BaseChatMemory, BaseModel):
@@ -5,9 +5,13 @@ from pydantic import BaseModel, root_validator
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.memory.utils import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import (
+    BaseLanguageModel,
+    BaseMessage,
+    SystemMessage,
+    get_buffer_string,
+)
 
 
 class SummarizerMixin(BaseModel):
@@ -4,8 +4,7 @@ from pydantic import BaseModel, root_validator
 
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
-from langchain.memory.utils import get_buffer_string
-from langchain.schema import BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
 
 
 class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
@@ -55,21 +54,17 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
         )
         return values
 
-    def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
-        """Get list of number of tokens in each string in the input array."""
-        return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
-
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
         super().save_context(inputs, outputs)
         # Prune buffer if it exceeds max token limit
         buffer = self.chat_memory.messages
-        curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit:
                 pruned_memory.append(buffer.pop(0))
-                curr_buffer_length = sum(self.get_num_tokens_list(buffer))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
             self.moving_summary_buffer = self.predict_new_summary(
                 pruned_memory, self.moving_summary_buffer
             )
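The switch from get_num_tokens_list to get_num_tokens_from_messages leaves the pruning strategy itself unchanged: drop the oldest messages until the buffer fits the token budget, then fold what was dropped into the rolling summary. A self-contained sketch of that loop, with a word-count stand-in for the real token counter:

    from typing import List, Tuple

    def prune(buffer: List[str], max_tokens: int) -> Tuple[List[str], List[str]]:
        # crude stand-in for llm.get_num_tokens_from_messages
        count = lambda msgs: sum(len(m.split()) for m in msgs)
        pruned: List[str] = []
        while buffer and count(buffer) > max_tokens:
            pruned.append(buffer.pop(0))  # oldest message goes first
        return pruned, buffer  # pruned messages feed the moving summary

    pruned, kept = prune(["hi there", "how are you", "fine thanks"], max_tokens=4)
    print(pruned, kept)  # ['hi there', 'how are you'] ['fine thanks']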
@@ -1,32 +1,6 @@
 from typing import Any, Dict, List
 
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-def get_buffer_string(
-    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
-) -> str:
-    """Get buffer string of messages."""
-    string_messages = []
-    for m in messages:
-        if isinstance(m, HumanMessage):
-            role = human_prefix
-        elif isinstance(m, AIMessage):
-            role = ai_prefix
-        elif isinstance(m, SystemMessage):
-            role = "System"
-        elif isinstance(m, ChatMessage):
-            role = m.role
-        else:
-            raise ValueError(f"Got unsupported message type: {m}")
-        string_messages.append(f"{role}: {m.content}")
-    return "\n".join(string_messages)
+from langchain.schema import get_buffer_string  # noqa: F401
 
 
 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
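Because memory.utils now re-exports the helper from its new home (hence the `# noqa: F401` above), existing imports keep resolving to the very same function:

    from langchain.memory.utils import get_buffer_string as legacy
    from langchain.schema import get_buffer_string as canonical

    assert legacy is canonical  # the old import path still works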
@@ -7,6 +7,26 @@ from typing import Any, Dict, List, NamedTuple, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
 
+def get_buffer_string(
+    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Get buffer string of messages."""
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        string_messages.append(f"{role}: {m.content}")
+    return "\n".join(string_messages)
+
+
 class AgentAction(NamedTuple):
     """Agent's action to take."""
 
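After the move, the helper lives in langchain.schema next to the message classes it dispatches on. A quick usage sketch; the output in the trailing comments assumes the default prefixes:

    from langchain.schema import AIMessage, HumanMessage, get_buffer_string

    buffer = get_buffer_string(
        [HumanMessage(content="hi"), AIMessage(content="hello!")]
    )
    print(buffer)
    # Human: hi
    # AI: hello!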
@@ -185,6 +205,10 @@ class BaseLanguageModel(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages."""
+        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
+
 
 class BaseMemory(BaseModel, ABC):
     """Base interface for memory in chains."""
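The base-class version is deliberately an approximation: it flattens each message to its buffer-string form and counts tokens there, while subclasses that know their exact wire format (like ChatOpenAI above) override it with a precise count. A toy illustration of that pattern; both classes and the whitespace "tokenizer" are made up for the sketch:

    from typing import List

    class ToyLanguageModel:
        """Stand-in for BaseLanguageModel; whitespace split fakes a tokenizer."""

        def get_num_tokens(self, text: str) -> int:
            return len(text.split())

        def get_num_tokens_from_messages(self, messages: List[str]) -> int:
            # default: approximate over the flattened message text
            return sum(self.get_num_tokens(m) for m in messages)

    class ToyChatModel(ToyLanguageModel):
        def get_num_tokens_from_messages(self, messages: List[str]) -> int:
            # a chat model that knows its wire format adds per-message framing
            # tokens (cf. the ChatML +4 rule in the ChatOpenAI hunk above)
            return sum(self.get_num_tokens(m) + 4 for m in messages)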