|
|
@ -5,12 +5,12 @@ from typing import Any, Dict, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
AsyncCallbackManagerForToolRun,
|
|
|
|
AsyncCallbackManagerForToolRun,
|
|
|
|
CallbackManagerForToolRun,
|
|
|
|
CallbackManagerForToolRun,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
|
|
|
|
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
@ -20,7 +20,7 @@ class BaseVectorStoreTool(BaseModel):
|
|
|
|
"""Base class for tools that use a VectorStore."""
|
|
|
|
"""Base class for tools that use a VectorStore."""
|
|
|
|
|
|
|
|
|
|
|
|
vectorstore: VectorStore = Field(exclude=True)
|
|
|
|
vectorstore: VectorStore = Field(exclude=True)
|
|
|
|
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
|
|
|
|
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
|
|
|
|
|
|
|
|
|
|
|
class Config(BaseTool.Config):
|
|
|
|
class Config(BaseTool.Config):
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|