forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
|
|
|
from pydantic import Extra
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
from langchain.chains.router.base import RouterChain
|
|
from langchain.docstore.document import Document
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
class EmbeddingRouterChain(RouterChain):
    """Class that uses embeddings to route between options.

    Each candidate destination is represented in ``vectorstore`` by one or more
    documents whose ``metadata["name"]`` is the destination name. At call time
    the values of ``routing_keys`` are joined into a query string and the name
    attached to the single most similar document is returned as the
    destination.
    """

    # Store searched at routing time; its documents must carry a "name"
    # metadata entry identifying the destination they describe.
    vectorstore: VectorStore
    # Input keys whose values are joined (with ", ") into the similarity query.
    routing_keys: List[str] = ["query"]

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown constructor fields instead of silently ignoring them.
        extra = Extra.forbid
        # Allow field types that are not pydantic models (e.g. VectorStore).
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys used to build the routing query.

        These are exactly ``routing_keys``.

        :meta private:
        """
        return self.routing_keys

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick the destination whose description best matches the inputs.

        Args:
            inputs: Chain inputs; must contain every key in ``routing_keys``,
                and those values are joined with ", " (so they are expected
                to be strings).
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            ``{"next_inputs": inputs, "destination": name}``, where ``inputs``
            is passed through unchanged and ``name`` is the ``"name"`` metadata
            of the top similarity-search hit.

        Raises:
            KeyError: if a routing key is missing from ``inputs`` or the
                matched document has no ``"name"`` metadata.
            IndexError: if the vector store search returns no documents.
        """
        query = ", ".join([inputs[k] for k in self.routing_keys])
        # Only the single nearest description is needed to choose a route.
        results = self.vectorstore.similarity_search(query, k=1)
        return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

    @classmethod
    def from_names_and_descriptions(
        cls,
        names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
        vectorstore_cls: Type[VectorStore],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> EmbeddingRouterChain:
        """Convenience constructor.

        Builds one ``Document`` per (destination, description) pair, indexes
        them with ``vectorstore_cls.from_documents(documents, embeddings)``
        and wraps the resulting store in a new router.

        Args:
            names_and_descriptions: ``(name, descriptions)`` pairs. Each
                description becomes a separate document tagged with
                ``metadata={"name": name}``. A single bare string is accepted
                as one description.
            vectorstore_cls: Vector store class used to index the documents.
            embeddings: Embedding model passed to ``from_documents``.
            **kwargs: Extra fields forwarded to the router constructor
                (e.g. ``routing_keys``).

        Returns:
            A new router backed by the freshly built vector store.
        """
        documents: List[Document] = []
        for name, descriptions in names_and_descriptions:
            # A str is itself a Sequence[str]; iterating it would embed one
            # document per character, so treat it as a single description.
            if isinstance(descriptions, str):
                descriptions = [descriptions]
            documents.extend(
                Document(page_content=description, metadata={"name": name})
                for description in descriptions
            )
        vectorstore = vectorstore_cls.from_documents(documents, embeddings)
        return cls(vectorstore=vectorstore, **kwargs)
|