2023-04-19 04:41:03 +00:00
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
|
2023-09-25 19:44:23 +00:00
|
|
|
from langchain.schema.vectorstore import VectorStoreRetriever
|
Use a submodule for pydantic v1 compat (#9371)
<!-- Thank you for contributing to LangChain!
Replace this entire comment with:
- Description: a description of the change,
- Issue: the issue # it fixes (if applicable),
- Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!
Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.
See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live in the docs/extras
directory.
If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
-->
2023-08-17 15:35:49 +00:00
|
|
|
|
|
|
|
from langchain_experimental.pydantic_v1 import Field
|
2023-04-19 04:41:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
class AutoGPTMemory(BaseChatMemory):
    """Memory for AutoGPT.

    Exposes two prompt variables: a window of the most recent chat
    messages, and the documents a vector-store retriever returns for the
    current prompt input.
    """

    # VectorStoreRetriever object to connect to; excluded from serialization.
    retriever: VectorStoreRetriever = Field(exclude=True)

    @property
    def memory_variables(self) -> List[str]:
        """Names of the variables this memory contributes to the prompt."""
        return ["chat_history", "relevant_context"]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt.

        Uses the explicitly configured ``input_key`` when there is one.
        Otherwise the choice is delegated to ``get_prompt_input_key``, which
        receives this memory's own variable names (presumably so they are
        not mistaken for the user input; confirm against langchain).
        """
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the memory variables for the current prompt.

        Args:
            inputs: Prompt inputs. The value under the resolved input key is
                used as the retrieval query; a missing key raises KeyError.

        Returns:
            A dict with ``chat_history`` (up to the last 10 stored messages)
            and ``relevant_context`` (the retriever's documents for the query).
        """
        prompt_key = self._get_prompt_input_key(inputs)
        relevant_docs = self.retriever.get_relevant_documents(inputs[prompt_key])
        recent_messages = self.chat_memory.messages[-10:]
        return {"chat_history": recent_messages, "relevant_context": relevant_docs}
|