From b1d69d3e7ac1f2f5ee43a0e2f1d68eeba75ef48f Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Sun, 30 Apr 2023 16:45:10 -0700 Subject: [PATCH] Vwp/fix vectorstore typing (#3851) Co-authored-by: Jay Stakelon --- langchain/agents/agent_toolkits/vectorstore/toolkit.py | 6 +++--- langchain/tools/vectorstore/tool.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/langchain/agents/agent_toolkits/vectorstore/toolkit.py index 72ad1262..22002a17 100644 --- a/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -4,7 +4,7 @@ from typing import List from pydantic import BaseModel, Field from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.llms.base import BaseLLM +from langchain.base_language import BaseLanguageModel from langchain.llms.openai import OpenAI from langchain.tools import BaseTool from langchain.tools.vectorstore.tool import ( @@ -31,7 +31,7 @@ class VectorStoreToolkit(BaseToolkit): """Toolkit for interacting with a vector store.""" vectorstore_info: VectorStoreInfo = Field(exclude=True) - llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) + llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) class Config: """Configuration for this pydantic object.""" @@ -65,7 +65,7 @@ class VectorStoreRouterToolkit(BaseToolkit): """Toolkit for routing between vectorstores.""" vectorstores: List[VectorStoreInfo] = Field(exclude=True) - llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) + llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) class Config: """Configuration for this pydantic object.""" diff --git a/langchain/tools/vectorstore/tool.py b/langchain/tools/vectorstore/tool.py index 983224b4..c5fae604 100644 --- a/langchain/tools/vectorstore/tool.py +++ b/langchain/tools/vectorstore/tool.py @@ -5,12 +5,12 @@ from typing import Any, Dict, Optional from pydantic import BaseModel, Field +from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain -from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI from langchain.tools.base import BaseTool from langchain.vectorstores.base import VectorStore @@ -20,7 +20,7 @@ class BaseVectorStoreTool(BaseModel): """Base class for tools that use a VectorStore.""" vectorstore: VectorStore = Field(exclude=True) - llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) + llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) class Config(BaseTool.Config): """Configuration for this pydantic object."""