Vwp/fix vectorstore typing (#3851)

Co-authored-by: Jay Stakelon <stakes@users.noreply.github.com>
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent fbbdf161cd
commit b1d69d3e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.llms.base import BaseLLM from langchain.base_language import BaseLanguageModel
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.tools.vectorstore.tool import ( from langchain.tools.vectorstore.tool import (
@ -31,7 +31,7 @@ class VectorStoreToolkit(BaseToolkit):
"""Toolkit for interacting with a vector store.""" """Toolkit for interacting with a vector store."""
vectorstore_info: VectorStoreInfo = Field(exclude=True) vectorstore_info: VectorStoreInfo = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -65,7 +65,7 @@ class VectorStoreRouterToolkit(BaseToolkit):
"""Toolkit for routing between vectorstores.""" """Toolkit for routing between vectorstores."""
vectorstores: List[VectorStoreInfo] = Field(exclude=True) vectorstores: List[VectorStoreInfo] = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""

@ -5,12 +5,12 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
@ -20,7 +20,7 @@ class BaseVectorStoreTool(BaseModel):
"""Base class for tools that use a VectorStore.""" """Base class for tools that use a VectorStore."""
vectorstore: VectorStore = Field(exclude=True) vectorstore: VectorStore = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config(BaseTool.Config): class Config(BaseTool.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""

Loading…
Cancel
Save