mv base cache to schema (#9953)

If you remove all other imports from langchain.__init__, it exposes a circular dependency.
Bagatur 2023-08-30 08:10:51 -07:00 committed by GitHub
parent 9870bfb9cd
commit 9828701de1
4 changed files with 30 additions and 23 deletions
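
For orientation, a minimal sketch of the import surface this change produces, based only on the diff below (an assumption, not something exercised here): BaseCache is now defined in langchain.schema.cache, re-exported from langchain.schema, and re-imported by langchain.cache, so all three paths should resolve to the same class.

# Expected import paths after this commit (assumption derived from the diff).
from langchain.schema import BaseCache
from langchain.schema.cache import BaseCache as SchemaBaseCache
from langchain.cache import BaseCache as LegacyBaseCache  # cache.py re-imports it

# All three names should point at the same class object.
assert BaseCache is SchemaBaseCache is LegacyBaseCache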

View File: langchain/__init__.py

@@ -4,7 +4,6 @@ from importlib import metadata
from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.chains import (
    ConversationChain,
    LLMBashChain,
@@ -40,6 +39,7 @@ from langchain.prompts import (
    Prompt,
    PromptTemplate,
)
from langchain.schema.cache import BaseCache
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.golden_query import GoldenQueryAPIWrapper

View File: langchain/cache.py

@@ -26,7 +26,6 @@ import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import (
    TYPE_CHECKING,
@@ -35,7 +34,6 @@ from typing import (
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
@@ -46,17 +44,18 @@ from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
from langchain.utils import get_from_env
try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.schema import ChatGeneration, Generation
from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
from langchain.utils import get_from_env
from langchain.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__)
@@ -64,8 +63,6 @@ logger = logging.getLogger(__file__)
if TYPE_CHECKING:
    import momento
RETURN_VAL_TYPE = Sequence[Generation]
def _hash(_input: str) -> str:
    """Use a deterministic hashing approach."""
@@ -105,22 +102,6 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
)
class BaseCache(ABC):
    """Base interface for cache."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""

    @abstractmethod
    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments."""


class InMemoryCache(BaseCache):
    """Cache that stores things in memory."""

View File: langchain/schema/__init__.py

@@ -1,5 +1,6 @@
"""**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.cache import BaseCache
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.exceptions import LangChainException
@@ -39,6 +40,7 @@ RUN_KEY = "__run"
Memory = BaseMemory
__all__ = [
"BaseCache",
"BaseMemory",
"BaseStore",
"AgentFinish",

View File: langchain/schema/cache.py

@@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain.schema.output import Generation
RETURN_VAL_TYPE = Sequence[Generation]
class BaseCache(ABC):
    """Base interface for cache."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""

    @abstractmethod
    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments."""