You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/chains/router/embedding_router.py

60 lines
1.9 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.router.base import RouterChain
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
class EmbeddingRouterChain(RouterChain):
    """Router chain that selects a destination via a vector store lookup.

    The values of ``routing_keys`` are read from the chain inputs, joined
    into one query string, and passed to ``vectorstore.similarity_search``.
    The ``"name"`` metadata entry of the first returned document is
    reported as the destination.
    """

    # Store that is queried with the joined routing text.
    vectorstore: VectorStore
    # Input keys whose values are joined (in this order) to form the query.
    routing_keys: List[str] = ["query"]

    class Config:
        """Pydantic settings for this model."""

        # Reject fields that are not declared on the model.
        extra = Extra.forbid
        # Allow non-pydantic field types such as ``VectorStore``.
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys this chain reads, i.e. ``routing_keys``.

        :meta private:
        """
        return self.routing_keys

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` to the destination of the first search result.

        Args:
            inputs: Chain inputs; must contain every key in ``routing_keys``.
                The values are joined with ``", "``, so they need to be
                strings.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            A dict with ``"next_inputs"`` (the unmodified ``inputs``) and
            ``"destination"`` (the ``"name"`` metadata of the first result).

        Raises:
            KeyError: If a routing key is missing from ``inputs`` or the
                first result has no ``"name"`` metadata entry.
            IndexError: If the similarity search returns no results.
        """
        routed_values = [inputs[key] for key in self.routing_keys]
        query_text = ", ".join(routed_values)
        matches = self.vectorstore.similarity_search(query_text, k=1)
        destination = matches[0].metadata["name"]
        return {"next_inputs": inputs, "destination": destination}

    @classmethod
    def from_names_and_descriptions(
        cls,
        names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
        vectorstore_cls: Type[VectorStore],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> EmbeddingRouterChain:
        """Build a router from ``(name, descriptions)`` pairs.

        Each description becomes one ``Document`` whose metadata carries the
        destination name under ``"name"``; the documents are handed to
        ``vectorstore_cls.from_documents`` together with ``embeddings``.

        Args:
            names_and_descriptions: Pairs of a destination name and the
                descriptions associated with it.
            vectorstore_cls: Vector store class used to build the store.
            embeddings: Embeddings object passed to ``from_documents``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new ``EmbeddingRouterChain`` wrapping the created store.
        """
        # Flatten to one document per (name, description) combination,
        # keeping the original iteration order.
        docs = [
            Document(page_content=text, metadata={"name": destination})
            for destination, texts in names_and_descriptions
            for text in texts
        ]
        store = vectorstore_cls.from_documents(docs, embeddings)
        return cls(vectorstore=store, **kwargs)