Removed duplicate BaseModel dependencies (#2471)

Removed duplicate BaseModel dependencies in class inheritances.
Also, sorted imports by `isort`.
doc
leo-gan 1 year ago committed by GitHub
parent b6a101d121
commit fd69cc7e42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -548,7 +548,7 @@ class Agent(BaseSingleActionAgent):
} }
class AgentExecutor(Chain, BaseModel): class AgentExecutor(Chain):
"""Consists of an agent using tools.""" """Consists of an agent using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]

@ -19,9 +19,7 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
REQUESTS_GET_TOOL_DESCRIPTION, REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION, REQUESTS_POST_TOOL_DESCRIPTION,
) )
from langchain.agents.agent_toolkits.openapi.spec import ( from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
ReducedOpenAPISpec,
)
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain

@ -2,7 +2,6 @@
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API. API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
You should: You should:

@ -1,11 +1,11 @@
# flake8: noqa # flake8: noqa
"""Load tools.""" """Load tools."""
from typing import Any, List, Optional
import warnings import warnings
from typing import Any, List, Optional
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.api import news_docs, open_meteo_docs, tmdb_docs, podcast_docs from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain from langchain.chains.pal.base import PALChain
@ -14,16 +14,16 @@ from langchain.requests import TextRequestsWrapper
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.bing_search.tool import BingSearchRun from langchain.tools.bing_search.tool import BingSearchRun
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
from langchain.tools.human.tool import HumanInputRun from langchain.tools.human.tool import HumanInputRun
from langchain.tools.python.tool import PythonREPLTool from langchain.tools.python.tool import PythonREPLTool
from langchain.tools.requests.tool import ( from langchain.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool, RequestsGetTool,
RequestsPostTool,
RequestsPatchTool, RequestsPatchTool,
RequestsPostTool,
RequestsPutTool, RequestsPutTool,
RequestsDeleteTool,
) )
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wikipedia.tool import WikipediaQueryRun
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain.utilities.apify import ApifyWrapper from langchain.utilities.apify import ApifyWrapper

@ -2,8 +2,6 @@
import re import re
from typing import Any, List, Optional, Sequence, Tuple from typing import Any, List, Optional, Sequence, Tuple
from pydantic import BaseModel
from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
@ -16,7 +14,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
class ReActDocstoreAgent(Agent, BaseModel): class ReActDocstoreAgent(Agent):
"""Agent for the ReAct chain.""" """Agent for the ReAct chain."""
@property @property
@ -124,7 +122,7 @@ class DocstoreExplorer:
return self.document.page_content.split("\n\n") return self.document.page_content.split("\n\n")
class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): class ReActTextWorldAgent(ReActDocstoreAgent):
"""Agent for the ReAct TextWorld chain.""" """Agent for the ReAct TextWorld chain."""
@classmethod @classmethod

@ -77,11 +77,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log data to Arize when an LLM ends.""" """Log data to Arize when an LLM ends."""
from arize.utils.types import ( from arize.utils.types import Embedding, Environments, ModelTypes
Embedding,
Environments,
ModelTypes,
)
# Record token usage of the LLM # Record token usage of the LLM
if response.llm_output is not None: if response.llm_output is not None:

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, root_validator from pydantic import Field, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -13,7 +13,7 @@ from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
class APIChain(Chain, BaseModel): class APIChain(Chain):
"""Chain that makes API calls and summarizes the responses to answer a question.""" """Chain that makes API calls and summarizes the responses to answer a question."""
api_request_chain: LLMChain api_request_chain: LLMChain

@ -3,14 +3,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field from pydantic import Field
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
class BaseCombineDocumentsChain(Chain, BaseModel, ABC): class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents.""" """Base interface for chains combining documents."""
input_key: str = "input_documents" #: :meta private: input_key: str = "input_documents" #: :meta private:
@ -66,7 +66,7 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
return extra_return_dict return extra_return_dict
class AnalyzeDocumentChain(Chain, BaseModel): class AnalyzeDocumentChain(Chain):
"""Chain that splits documents, then analyzes it in pieces.""" """Chain that splits documents, then analyzes it in pieces."""
input_key: str = "input_document" #: :meta private: input_key: str = "input_document" #: :meta private:

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -59,7 +59,7 @@ def _collapse_docs(
return Document(page_content=result, metadata=combined_metadata) return Document(page_content=result, metadata=combined_metadata)
class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then combining results.""" """Combining documents by mapping a chain over them, then combining results."""
llm_chain: LLMChain llm_chain: LLMChain

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -12,7 +12,7 @@ from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then reranking results.""" """Combining documents by mapping a chain over them, then reranking results."""
llm_chain: LLMChain llm_chain: LLMChain

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -17,7 +17,7 @@ def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}") return PromptTemplate(input_variables=["page_content"], template="{page_content}")
class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents by doing a first pass and then refining on more documents.""" """Combine documents by doing a first pass and then refining on more documents."""
initial_llm_chain: LLMChain initial_llm_chain: LLMChain

@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -15,7 +15,7 @@ def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}") return PromptTemplate(input_variables=["page_content"], template="{page_content}")
class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context.""" """Chain that combines documents by stuffing into context."""
llm_chain: LLMChain llm_chain: LLMChain

@ -1,5 +1,6 @@
# flake8: noqa # flake8: noqa
from typing import Dict from typing import Dict
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
PRINCIPLES: Dict[str, ConstitutionalPrinciple] = {} PRINCIPLES: Dict[str, ConstitutionalPrinciple] = {}

@ -1,7 +1,7 @@
"""Chain that carries on a conversation and calls an LLM.""" """Chain that carries on a conversation and calls an LLM."""
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains.conversation.prompt import PROMPT from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -10,7 +10,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
class ConversationChain(LLMChain, BaseModel): class ConversationChain(LLMChain):
"""Chain to have a conversation and load context from memory. """Chain to have a conversation and load context from memory.
Example: Example:

@ -6,7 +6,7 @@ from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -28,7 +28,7 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
return buffer return buffer
class BaseConversationalRetrievalChain(Chain, BaseModel): class BaseConversationalRetrievalChain(Chain):
"""Chain for chatting with an index.""" """Chain for chatting with an index."""
combine_docs_chain: BaseCombineDocumentsChain combine_docs_chain: BaseCombineDocumentsChain
@ -116,7 +116,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
super().save(file_path) super().save(file_path)
class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel): class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""Chain for chatting with an index.""" """Chain for chatting with an index."""
retriever: BaseRetriever retriever: BaseRetriever
@ -175,7 +175,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
) )
class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel): class ChatVectorDBChain(BaseConversationalRetrievalChain):
"""Chain for chatting with a vector database.""" """Chain for chatting with a vector database."""
vectorstore: VectorStore = Field(alias="vectorstore") vectorstore: VectorStore = Field(alias="vectorstore")

@ -7,7 +7,7 @@ from __future__ import annotations
from typing import Dict, List from typing import Dict, List
import numpy as np import numpy as np
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.hyde.prompts import PROMPT_MAP
@ -16,7 +16,7 @@ from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel): class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Generate hypothetical document for query, and then embed that. """Generate hypothetical document for query, and then embed that.
Based on https://arxiv.org/abs/2212.10496 Based on https://arxiv.org/abs/2212.10496

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_colored_text from langchain.input import get_colored_text
@ -12,7 +12,7 @@ from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
class LLMChain(Chain, BaseModel): class LLMChain(Chain):
"""Chain to run queries against LLMs. """Chain to run queries against LLMs.
Example: Example:

@ -1,7 +1,7 @@
"""Chain that interprets a prompt and executes bash code to perform bash operations.""" """Chain that interprets a prompt and executes bash code to perform bash operations."""
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -11,7 +11,7 @@ from langchain.schema import BaseLanguageModel
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
class LLMBashChain(Chain, BaseModel): class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash code to perform bash operations. """Chain that interprets a prompt and executes bash code to perform bash operations.
Example: Example:

@ -3,7 +3,7 @@
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -18,7 +18,7 @@ from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
class LLMCheckerChain(Chain, BaseModel): class LLMCheckerChain(Chain):
"""Chain for question-answering with self-verification. """Chain for question-answering with self-verification.
Example: Example:

@ -1,7 +1,7 @@
"""Chain that interprets a prompt and executes python code to do math.""" """Chain that interprets a prompt and executes python code to do math."""
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -11,7 +11,7 @@ from langchain.python import PythonREPL
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
class LLMMathChain(Chain, BaseModel): class LLMMathChain(Chain):
"""Chain that interprets a prompt and executes python code to do math. """Chain that interprets a prompt and executes python code to do math.
Example: Example:

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -14,7 +14,7 @@ DEFAULT_HEADERS = {
} }
class LLMRequestsChain(Chain, BaseModel): class LLMRequestsChain(Chain):
"""Chain that hits a URL and then uses an LLM to parse results.""" """Chain that hits a URL and then uses an LLM to parse results."""
llm_chain: LLMChain llm_chain: LLMChain

@ -3,7 +3,7 @@
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -27,7 +27,7 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
) )
class LLMSummarizationCheckerChain(Chain, BaseModel): class LLMSummarizationCheckerChain(Chain):
"""Chain for question-answering with self-verification. """Chain for question-answering with self-verification.
Example: Example:

@ -7,7 +7,7 @@ from __future__ import annotations
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -20,7 +20,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import TextSplitter from langchain.text_splitter import TextSplitter
class MapReduceChain(Chain, BaseModel): class MapReduceChain(Chain):
"""Map-reduce chain.""" """Map-reduce chain."""
combine_documents_chain: BaseCombineDocumentsChain combine_documents_chain: BaseCombineDocumentsChain

@ -1,13 +1,13 @@
"""Pass input through a moderation endpoint.""" """Pass input through a moderation endpoint."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator from pydantic import root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class OpenAIModerationChain(Chain, BaseModel): class OpenAIModerationChain(Chain):
"""Pass input through a moderation endpoint. """Pass input through a moderation endpoint.
To use, you should have the ``openai`` python package installed, and the To use, you should have the ``openai`` python package installed, and the

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -12,7 +12,7 @@ from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
class NatBotChain(Chain, BaseModel): class NatBotChain(Chain):
"""Implement an LLM driven browser. """Implement an LLM driven browser.
Example: Example:

@ -6,7 +6,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -17,7 +17,7 @@ from langchain.python import PythonREPL
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
class PALChain(Chain, BaseModel): class PALChain(Chain):
"""Implements Program-Aided Language Models.""" """Implements Program-Aided Language Models."""
llm: BaseLanguageModel llm: BaseLanguageModel

@ -6,7 +6,7 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -24,7 +24,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, BaseModel, ABC): class BaseQAWithSourcesChain(Chain, ABC):
"""Question answering with sources over documents.""" """Question answering with sources over documents."""
combine_documents_chain: BaseCombineDocumentsChain combine_documents_chain: BaseCombineDocumentsChain
@ -149,7 +149,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
return result return result
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): class QAWithSourcesChain(BaseQAWithSourcesChain):
"""Question answering with sources over documents.""" """Question answering with sources over documents."""
input_docs_key: str = "docs" #: :meta private: input_docs_key: str = "docs" #: :meta private:

@ -2,7 +2,7 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Field from pydantic import Field
from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
@ -10,7 +10,7 @@ from langchain.docstore.document import Document
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
"""Question-answering with sources over an index.""" """Question-answering with sources over an index."""
retriever: BaseRetriever = Field(exclude=True) retriever: BaseRetriever = Field(exclude=True)

@ -3,7 +3,7 @@
import warnings import warnings
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Field, root_validator from pydantic import Field, root_validator
from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
@ -11,7 +11,7 @@ from langchain.docstore.document import Document
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
"""Question-answering with sources over a vector database.""" """Question-answering with sources over a vector database."""
vectorstore: VectorStore = Field(exclude=True) vectorstore: VectorStore = Field(exclude=True)

@ -1,14 +1,11 @@
# flake8: noqa # flake8: noqa
from langchain.prompts.prompt import PromptTemplate from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.chat import ( from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
) )
from langchain.chains.prompt_selector import ( from langchain.prompts.prompt import PromptTemplate
ConditionalPromptSelector,
is_chat_model,
)
question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
Return any relevant text verbatim. Return any relevant text verbatim.

@ -1,6 +1,6 @@
# flake8: noqa # flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.prompts import PromptTemplate
output_parser = RegexParser( output_parser = RegexParser(
regex=r"(.*?)\nScore: (.*)", regex=r"(.*?)\nScore: (.*)",

@ -1,16 +1,12 @@
# flake8: noqa # flake8: noqa
from langchain.prompts.prompt import PromptTemplate from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.chat import ( from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
AIMessagePromptTemplate, AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
) )
from langchain.chains.prompt_selector import ( from langchain.prompts.prompt import PromptTemplate
ConditionalPromptSelector,
is_chat_model,
)
DEFAULT_REFINE_PROMPT_TMPL = ( DEFAULT_REFINE_PROMPT_TMPL = (
"The original question is as follows: {question}\n" "The original question is as follows: {question}\n"

@ -1,16 +1,12 @@
# flake8: noqa # flake8: noqa
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)
from langchain.prompts.chat import ( from langchain.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
) )
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context} {context}

@ -5,7 +5,7 @@ import warnings
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -18,7 +18,7 @@ from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
class BaseRetrievalQA(Chain, BaseModel): class BaseRetrievalQA(Chain):
combine_documents_chain: BaseCombineDocumentsChain combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents.""" """Chain to use to combine the documents."""
input_key: str = "query" #: :meta private: input_key: str = "query" #: :meta private:
@ -143,7 +143,7 @@ class BaseRetrievalQA(Chain, BaseModel):
return {self.output_key: answer} return {self.output_key: answer}
class RetrievalQA(BaseRetrievalQA, BaseModel): class RetrievalQA(BaseRetrievalQA):
"""Chain for question-answering against an index. """Chain for question-answering against an index.
Example: Example:
@ -166,7 +166,7 @@ class RetrievalQA(BaseRetrievalQA, BaseModel):
return await self.retriever.aget_relevant_documents(question) return await self.retriever.aget_relevant_documents(question)
class VectorDBQA(BaseRetrievalQA, BaseModel): class VectorDBQA(BaseRetrievalQA):
"""Chain for question-answering against a vector database.""" """Chain for question-answering against a vector database."""
vectorstore: VectorStore = Field(exclude=True, alias="vectorstore") vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")

@ -1,13 +1,13 @@
"""Chain pipeline where the outputs of one step feed directly into next.""" """Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_color_mapping from langchain.input import get_color_mapping
class SequentialChain(Chain, BaseModel): class SequentialChain(Chain):
"""Chain where the outputs of one chain feed directly into next.""" """Chain where the outputs of one chain feed directly into next."""
chains: List[Chain] chains: List[Chain]
@ -94,7 +94,7 @@ class SequentialChain(Chain, BaseModel):
return {k: known_values[k] for k in self.output_variables} return {k: known_values[k] for k in self.output_variables}
class SimpleSequentialChain(Chain, BaseModel): class SimpleSequentialChain(Chain):
"""Simple chain where the outputs of one step feed directly into next.""" """Simple chain where the outputs of one step feed directly into next."""
chains: List[Chain] chains: List[Chain]

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Extra, Field from pydantic import Extra, Field
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -13,7 +13,7 @@ from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
class SQLDatabaseChain(Chain, BaseModel): class SQLDatabaseChain(Chain):
"""Chain for interacting with SQL Database. """Chain for interacting with SQL Database.
Example: Example:
@ -107,7 +107,7 @@ class SQLDatabaseChain(Chain, BaseModel):
return "sql_database_chain" return "sql_database_chain"
class SQLDatabaseSequentialChain(Chain, BaseModel): class SQLDatabaseSequentialChain(Chain):
"""Chain for querying SQL database that is a sequential chain. """Chain for querying SQL database that is a sequential chain.
The chain is as follows: The chain is as follows:

@ -1,12 +1,10 @@
"""Chain that runs an arbitrary python function.""" """Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List from typing import Callable, Dict, List
from pydantic import BaseModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
class TransformChain(Chain, BaseModel): class TransformChain(Chain):
"""Chain transform chain output. """Chain transform chain output.
Example: Example:

@ -6,9 +6,7 @@ from typing import Any, Dict
from pydantic import root_validator from pydantic import root_validator
from langchain.chat_models.openai import ( from langchain.chat_models.openai import ChatOpenAI
ChatOpenAI,
)
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)

@ -2,7 +2,7 @@ import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Extra, Field, validator from pydantic import Extra, Field, validator
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
@ -23,7 +23,7 @@ def _get_verbosity() -> bool:
return langchain.verbose return langchain.verbose
class BaseChatModel(BaseLanguageModel, BaseModel, ABC): class BaseChatModel(BaseLanguageModel, ABC):
verbose: bool = Field(default_factory=_get_verbosity) verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text.""" """Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

@ -5,7 +5,7 @@ import logging
import sys import sys
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@ -91,7 +91,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict return message_dict
class ChatOpenAI(BaseChatModel, BaseModel): class ChatOpenAI(BaseChatModel):
"""Wrapper around OpenAI Chat large language models. """Wrapper around OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the To use, you should have the ``openai`` python package installed, and the

@ -2,13 +2,11 @@
import datetime import datetime
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult from langchain.schema import BaseMessage, ChatResult
class PromptLayerChatOpenAI(ChatOpenAI, BaseModel): class PromptLayerChatOpenAI(ChatOpenAI):
"""Wrapper around OpenAI Chat large language models and PromptLayer. """Wrapper around OpenAI Chat large language models and PromptLayer.
To use, you should have the ``openai`` and ``promptlayer`` python To use, you should have the ``openai`` and ``promptlayer`` python

@ -11,9 +11,7 @@ from langchain.document_loaders.azure_blob_storage_file import (
) )
from langchain.document_loaders.bigquery import BigQueryLoader from langchain.document_loaders.bigquery import BigQueryLoader
from langchain.document_loaders.blackboard import BlackboardLoader from langchain.document_loaders.blackboard import BlackboardLoader
from langchain.document_loaders.college_confidential import ( from langchain.document_loaders.college_confidential import CollegeConfidentialLoader
CollegeConfidentialLoader,
)
from langchain.document_loaders.conllu import CoNLLULoader from langchain.document_loaders.conllu import CoNLLULoader
from langchain.document_loaders.csv_loader import CSVLoader from langchain.document_loaders.csv_loader import CSVLoader
from langchain.document_loaders.dataframe import DataFrameLoader from langchain.document_loaders.dataframe import DataFrameLoader
@ -66,9 +64,7 @@ from langchain.document_loaders.url import UnstructuredURLLoader
from langchain.document_loaders.url_selenium import SeleniumURLLoader from langchain.document_loaders.url_selenium import SeleniumURLLoader
from langchain.document_loaders.web_base import WebBaseLoader from langchain.document_loaders.web_base import WebBaseLoader
from langchain.document_loaders.whatsapp_chat import WhatsAppChatLoader from langchain.document_loaders.whatsapp_chat import WhatsAppChatLoader
from langchain.document_loaders.word_document import ( from langchain.document_loaders.word_document import UnstructuredWordDocumentLoader
UnstructuredWordDocumentLoader,
)
from langchain.document_loaders.youtube import ( from langchain.document_loaders.youtube import (
GoogleApiClient, GoogleApiClient,
GoogleApiYoutubeLoader, GoogleApiYoutubeLoader,

@ -54,9 +54,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY" values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY"
) )
try: try:
from aleph_alpha_client import ( from aleph_alpha_client import Client
Client,
)
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Could not import aleph_alpha_client python package. " "Could not import aleph_alpha_client python package. "

@ -1,7 +1,7 @@
"""Running custom embedding models on self-hosted remote hardware.""" """Running custom embedding models on self-hosted remote hardware."""
from typing import Any, Callable, List from typing import Any, Callable, List
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.llms import SelfHostedPipeline from langchain.llms import SelfHostedPipeline
@ -16,7 +16,7 @@ def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[floa
return pipeline(*args, **kwargs) return pipeline(*args, **kwargs)
class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings, BaseModel): class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
"""Runs custom embedding models on self-hosted remote hardware. """Runs custom embedding models on self-hosted remote hardware.
Supported hardware includes auto-launched instances on AWS, GCP, Azure, Supported hardware includes auto-launched instances on AWS, GCP, Azure,

@ -3,8 +3,6 @@ import importlib
import logging import logging
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
from pydantic import BaseModel
from langchain.embeddings.self_hosted import SelfHostedEmbeddings from langchain.embeddings.self_hosted import SelfHostedEmbeddings
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
@ -59,7 +57,7 @@ def load_embedding_model(model_id: str, instruct: bool = False, device: int = 0)
return client return client
class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings, BaseModel): class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings):
"""Runs sentence_transformers embedding models on self-hosted remote hardware. """Runs sentence_transformers embedding models on self-hosted remote hardware.
Supported hardware includes auto-launched instances on AWS, GCP, Azure, Supported hardware includes auto-launched instances on AWS, GCP, Azure,

@ -1,6 +1,6 @@
# flake8: noqa # flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.prompts import PromptTemplate
template = """You are a teacher coming up with questions to ask on a quiz. template = """You are a teacher coming up with questions to ask on a quiz.
Given the following document, please generate a question and answer based on that document. Given the following document, please generate a question and answer based on that document.

@ -19,7 +19,7 @@ class AI21PenaltyData(BaseModel):
applyToEmojis: bool = True applyToEmojis: bool = True
class AI21(LLM, BaseModel): class AI21(LLM):
"""Wrapper around AI21 large language models. """Wrapper around AI21 large language models.
To use, you should have the environment variable ``AI21_API_KEY`` To use, you should have the environment variable ``AI21_API_KEY``

@ -1,14 +1,14 @@
"""Wrapper around Aleph Alpha APIs.""" """Wrapper around Aleph Alpha APIs."""
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class AlephAlpha(LLM, BaseModel): class AlephAlpha(LLM):
"""Wrapper around Aleph Alpha large language models. """Wrapper around Aleph Alpha large language models.
To use, you should have the ``aleph_alpha_client`` python package installed, and the To use, you should have the ``aleph_alpha_client`` python package installed, and the

@ -2,13 +2,13 @@
import re import re
from typing import Any, Dict, Generator, List, Mapping, Optional from typing import Any, Dict, Generator, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class Anthropic(LLM, BaseModel): class Anthropic(LLM):
r"""Wrapper around Anthropic large language models. r"""Wrapper around Anthropic large language models.
To use, you should have the ``anthropic`` python package installed, and the To use, you should have the ``anthropic`` python package installed, and the

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Banana(LLM, BaseModel): class Banana(LLM):
"""Wrapper around Banana large language models. """Wrapper around Banana large language models.
To use, you should have the ``banana-dev`` python package installed, To use, you should have the ``banana-dev`` python package installed,

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import yaml import yaml
from pydantic import BaseModel, Extra, Field, validator from pydantic import Extra, Field, validator
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
@ -53,7 +53,7 @@ def update_cache(
return llm_output return llm_output
class BaseLLM(BaseLanguageModel, BaseModel, ABC): class BaseLLM(BaseLanguageModel, ABC):
"""LLM wrapper should take in a prompt and return a string.""" """LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None cache: Optional[bool] = None

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CerebriumAI(LLM, BaseModel): class CerebriumAI(LLM):
"""Wrapper around CerebriumAI large language models. """Wrapper around CerebriumAI large language models.
To use, you should have the ``cerebrium`` python package installed, and the To use, you should have the ``cerebrium`` python package installed, and the

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Cohere(LLM, BaseModel): class Cohere(LLM):
"""Wrapper around Cohere large language models. """Wrapper around Cohere large language models.
To use, you should have the ``cohere`` python package installed, and the To use, you should have the ``cohere`` python package installed, and the

@ -2,7 +2,7 @@
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
DEFAULT_MODEL_ID = "google/flan-t5-xl" DEFAULT_MODEL_ID = "google/flan-t5-xl"
class DeepInfra(LLM, BaseModel): class DeepInfra(LLM):
"""Wrapper around DeepInfra deployed models. """Wrapper around DeepInfra deployed models.
To use, you should have the ``requests`` python package installed, and the To use, you should have the ``requests`` python package installed, and the

@ -1,12 +1,10 @@
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.llms.base import LLM from langchain.llms.base import LLM
class FakeListLLM(LLM, BaseModel): class FakeListLLM(LLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
responses: List responses: List

@ -2,14 +2,14 @@
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class ForefrontAI(LLM, BaseModel): class ForefrontAI(LLM):
"""Wrapper around ForefrontAI large language models. """Wrapper around ForefrontAI large language models.
To use, you should have the environment variable ``FOREFRONTAI_API_KEY`` To use, you should have the environment variable ``FOREFRONTAI_API_KEY``

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -10,7 +10,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GooseAI(LLM, BaseModel): class GooseAI(LLM):
"""Wrapper around OpenAI large language models. """Wrapper around OpenAI large language models.
To use, you should have the ``openai`` python package installed, and the To use, you should have the ``openai`` python package installed, and the

@ -1,13 +1,13 @@
"""Wrapper for the GPT4All model.""" """Wrapper for the GPT4All model."""
from typing import Any, Dict, List, Mapping, Optional, Set from typing import Any, Dict, List, Mapping, Optional, Set
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
class GPT4All(LLM, BaseModel): class GPT4All(LLM):
r"""Wrapper around GPT4All language models. r"""Wrapper around GPT4All language models.
To use, you should have the ``pyllamacpp`` python package installed, the To use, you should have the ``pyllamacpp`` python package installed, the

@ -2,7 +2,7 @@
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
VALID_TASKS = ("text2text-generation", "text-generation") VALID_TASKS = ("text2text-generation", "text-generation")
class HuggingFaceEndpoint(LLM, BaseModel): class HuggingFaceEndpoint(LLM):
"""Wrapper around HuggingFaceHub Inference Endpoints. """Wrapper around HuggingFaceHub Inference Endpoints.
To use, you should have the ``huggingface_hub`` python package installed, and the To use, you should have the ``huggingface_hub`` python package installed, and the

@ -1,7 +1,7 @@
"""Wrapper around HuggingFace APIs.""" """Wrapper around HuggingFace APIs."""
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation") VALID_TASKS = ("text2text-generation", "text-generation")
class HuggingFaceHub(LLM, BaseModel): class HuggingFaceHub(LLM):
"""Wrapper around HuggingFaceHub models. """Wrapper around HuggingFaceHub models.
To use, you should have the ``huggingface_hub`` python package installed, and the To use, you should have the ``huggingface_hub`` python package installed, and the

@ -3,7 +3,7 @@ import importlib.util
import logging import logging
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -15,7 +15,7 @@ VALID_TASKS = ("text2text-generation", "text-generation")
logger = logging.getLogger() logger = logging.getLogger()
class HuggingFacePipeline(LLM, BaseModel): class HuggingFacePipeline(LLM):
"""Wrapper around HuggingFace Pipeline API. """Wrapper around HuggingFace Pipeline API.
To use, you should have the ``transformers`` python package installed. To use, you should have the ``transformers`` python package installed.

@ -2,14 +2,14 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, root_validator from pydantic import Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LlamaCpp(LLM, BaseModel): class LlamaCpp(LLM):
"""Wrapper around the llama.cpp model. """Wrapper around the llama.cpp model.
To use, you should have the llama-cpp-python library installed, and provide the To use, you should have the llama-cpp-python library installed, and provide the

@ -1,12 +1,12 @@
"""Wrapper around HazyResearch's Manifest library.""" """Wrapper around HazyResearch's Manifest library."""
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
class ManifestWrapper(LLM, BaseModel): class ManifestWrapper(LLM):
"""Wrapper around HazyResearch's Manifest library.""" """Wrapper around HazyResearch's Manifest library."""
client: Any #: :meta private: client: Any #: :meta private:

@ -3,7 +3,7 @@ import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Modal(LLM, BaseModel): class Modal(LLM):
"""Wrapper around Modal large language models. """Wrapper around Modal large language models.
To use, you should have the ``modal-client`` python package installed. To use, you should have the ``modal-client`` python package installed.

@ -1,13 +1,13 @@
"""Wrapper around NLPCloud APIs.""" """Wrapper around NLPCloud APIs."""
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class NLPCloud(LLM, BaseModel): class NLPCloud(LLM):
"""Wrapper around NLPCloud large language models. """Wrapper around NLPCloud large language models.
To use, you should have the ``nlpcloud`` python package installed, and the To use, you should have the ``nlpcloud`` python package installed, and the

@ -17,7 +17,7 @@ from typing import (
Union, Union,
) )
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@ -113,7 +113,7 @@ async def acompletion_with_retry(
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
class BaseOpenAI(BaseLLM, BaseModel): class BaseOpenAI(BaseLLM):
"""Wrapper around OpenAI large language models. """Wrapper around OpenAI large language models.
To use, you should have the ``openai`` python package installed, and the To use, you should have the ``openai`` python package installed, and the
@ -534,7 +534,7 @@ class AzureOpenAI(BaseOpenAI):
return {**{"engine": self.deployment_name}, **super()._invocation_params} return {**{"engine": self.deployment_name}, **super()._invocation_params}
class OpenAIChat(BaseLLM, BaseModel): class OpenAIChat(BaseLLM):
"""Wrapper around OpenAI Chat large language models. """Wrapper around OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the To use, you should have the ``openai`` python package installed, and the

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Petals(LLM, BaseModel): class Petals(LLM):
"""Wrapper around Petals Bloom models. """Wrapper around Petals Bloom models.
To use, you should have the ``petals`` python package installed, and the To use, you should have the ``petals`` python package installed, and the

@ -2,13 +2,11 @@
import datetime import datetime
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel
from langchain.llms import OpenAI, OpenAIChat from langchain.llms import OpenAI, OpenAIChat
from langchain.schema import LLMResult from langchain.schema import LLMResult
class PromptLayerOpenAI(OpenAI, BaseModel): class PromptLayerOpenAI(OpenAI):
"""Wrapper around OpenAI large language models. """Wrapper around OpenAI large language models.
To use, you should have the ``openai`` and ``promptlayer`` python To use, you should have the ``openai`` and ``promptlayer`` python
@ -106,7 +104,7 @@ class PromptLayerOpenAI(OpenAI, BaseModel):
return generated_responses return generated_responses
class PromptLayerOpenAIChat(OpenAIChat, BaseModel): class PromptLayerOpenAIChat(OpenAIChat):
"""Wrapper around OpenAI large language models. """Wrapper around OpenAI large language models.
To use, you should have the ``openai`` and ``promptlayer`` python To use, you should have the ``openai`` and ``promptlayer`` python

@ -2,7 +2,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -10,7 +10,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Replicate(LLM, BaseModel): class Replicate(LLM):
"""Wrapper around Replicate models. """Wrapper around Replicate models.
To use, you should have the ``replicate`` python package installed, To use, you should have the ``replicate`` python package installed,

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union from typing import Any, Dict, List, Mapping, Optional, Union
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -55,7 +55,7 @@ class ContentHandlerBase(ABC):
""" """
class SagemakerEndpoint(LLM, BaseModel): class SagemakerEndpoint(LLM):
"""Wrapper around custom Sagemaker Inference Endpoints. """Wrapper around custom Sagemaker Inference Endpoints.
To use, you must supply the endpoint name from your deployed To use, you must supply the endpoint name from your deployed

@ -4,7 +4,7 @@ import logging
import pickle import pickle
from typing import Any, Callable, List, Mapping, Optional from typing import Any, Callable, List, Mapping, Optional
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -61,7 +61,7 @@ def _send_pipeline_to_device(pipeline: Any, device: int) -> Any:
return pipeline return pipeline
class SelfHostedPipeline(LLM, BaseModel): class SelfHostedPipeline(LLM):
"""Run model inference on self-hosted remote hardware. """Run model inference on self-hosted remote hardware.
Supported hardware includes auto-launched instances on AWS, GCP, Azure, Supported hardware includes auto-launched instances on AWS, GCP, Azure,

@ -3,7 +3,7 @@ import importlib.util
import logging import logging
from typing import Any, Callable, List, Mapping, Optional from typing import Any, Callable, List, Mapping, Optional
from pydantic import BaseModel, Extra from pydantic import Extra
from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -108,7 +108,7 @@ def _load_transformer(
return pipeline return pipeline
class SelfHostedHuggingFaceLLM(SelfHostedPipeline, BaseModel): class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
"""Wrapper around HuggingFace Pipeline API to run on self-hosted remote hardware. """Wrapper around HuggingFace Pipeline API to run on self-hosted remote hardware.
Supported hardware includes auto-launched instances on AWS, GCP, Azure, Supported hardware includes auto-launched instances on AWS, GCP, Azure,

@ -4,7 +4,7 @@ import time
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import Extra, Field, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
@ -13,7 +13,7 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StochasticAI(LLM, BaseModel): class StochasticAI(LLM):
"""Wrapper around StochasticAI large language models. """Wrapper around StochasticAI large language models.
To use, you should have the environment variable ``STOCHASTICAI_API_KEY`` To use, you should have the environment variable ``STOCHASTICAI_API_KEY``

@ -2,14 +2,14 @@
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class Writer(LLM, BaseModel): class Writer(LLM):
"""Wrapper around Writer large language models. """Wrapper around Writer large language models.
To use, you should have the environment variable ``WRITER_API_KEY`` To use, you should have the environment variable ``WRITER_API_KEY``

@ -1,13 +1,13 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator from pydantic import root_validator
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key from langchain.memory.utils import get_prompt_input_key
from langchain.schema import get_buffer_string from langchain.schema import get_buffer_string
class ConversationBufferMemory(BaseChatMemory, BaseModel): class ConversationBufferMemory(BaseChatMemory):
"""Buffer for storing conversation memory.""" """Buffer for storing conversation memory."""
human_prefix: str = "Human" human_prefix: str = "Human"
@ -39,7 +39,7 @@ class ConversationBufferMemory(BaseChatMemory, BaseModel):
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}
class ConversationStringBufferMemory(BaseMemory, BaseModel): class ConversationStringBufferMemory(BaseMemory):
"""Buffer for storing conversation memory.""" """Buffer for storing conversation memory."""
human_prefix: str = "Human" human_prefix: str = "Human"

@ -1,12 +1,10 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage, get_buffer_string from langchain.schema import BaseMessage, get_buffer_string
class ConversationBufferWindowMemory(BaseChatMemory, BaseModel): class ConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory.""" """Buffer for storing conversation memory."""
human_prefix: str = "Human" human_prefix: str = "Human"

@ -5,10 +5,7 @@ from pydantic import Field
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.utils import get_prompt_input_key from langchain.memory.utils import get_prompt_input_key
from langchain.schema import ( from langchain.schema import BaseChatMessageHistory, BaseMemory
BaseChatMessageHistory,
BaseMemory,
)
class BaseChatMemory(BaseMemory, ABC): class BaseChatMemory(BaseMemory, ABC):

@ -1,11 +1,9 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
class CombinedMemory(BaseMemory, BaseModel): class CombinedMemory(BaseMemory):
"""Class for combining multiple memories' data together.""" """Class for combining multiple memories' data together."""
memories: List[BaseMemory] memories: List[BaseMemory]

@ -1,7 +1,5 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import ( from langchain.memory.prompt import (
@ -13,7 +11,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
class ConversationEntityMemory(BaseChatMemory, BaseModel): class ConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer to memory.""" """Entity extractor & summarizer to memory."""
human_prefix: str = "Human" human_prefix: str = "Human"

@ -1,6 +1,6 @@
from typing import Any, Dict, List, Type, Union from typing import Any, Dict, List, Type, Union
from pydantic import BaseModel, Field from pydantic import Field
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph from langchain.graphs import NetworkxEntityGraph
@ -20,7 +20,7 @@ from langchain.schema import (
) )
class ConversationKGMemory(BaseChatMemory, BaseModel): class ConversationKGMemory(BaseChatMemory):
"""Knowledge graph memory for storing conversation memory. """Knowledge graph memory for storing conversation memory.
Integrates with external knowledge graph to store and retrieve Integrates with external knowledge graph to store and retrieve

@ -1,11 +1,9 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.schema import BaseMemory from langchain.schema import BaseMemory
class SimpleMemory(BaseMemory, BaseModel): class SimpleMemory(BaseMemory):
"""Simple memory for storing context or other bits of information that shouldn't """Simple memory for storing context or other bits of information that shouldn't
ever change between prompts. ever change between prompts.
""" """

@ -34,7 +34,7 @@ class SummarizerMixin(BaseModel):
return chain.predict(summary=existing_summary, new_lines=new_lines) return chain.predict(summary=existing_summary, new_lines=new_lines)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel): class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""Conversation summarizer to memory.""" """Conversation summarizer to memory."""
buffer: str = "" buffer: str = ""

@ -1,13 +1,13 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, root_validator from pydantic import root_validator
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin from langchain.memory.summary import SummarizerMixin
from langchain.schema import BaseMessage, get_buffer_string from langchain.schema import BaseMessage, get_buffer_string
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel): class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
"""Buffer with summarizer for storing conversation memory.""" """Buffer with summarizer for storing conversation memory."""
max_token_limit: int = 2000 max_token_limit: int = 2000

@ -1,12 +1,10 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
class ConversationTokenBufferMemory(BaseChatMemory, BaseModel): class ConversationTokenBufferMemory(BaseChatMemory):
"""Buffer for storing conversation memory.""" """Buffer for storing conversation memory."""
human_prefix: str = "Human" human_prefix: str = "Human"

@ -3,12 +3,10 @@ from __future__ import annotations
import re import re
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
class RegexParser(BaseOutputParser, BaseModel): class RegexParser(BaseOutputParser):
"""Class to parse the output into a dictionary.""" """Class to parse the output into a dictionary."""
regex: str regex: str

@ -3,12 +3,10 @@ from __future__ import annotations
import re import re
from typing import Dict, Optional from typing import Dict, Optional
from pydantic import BaseModel
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
class RegexDictParser(BaseOutputParser, BaseModel): class RegexDictParser(BaseOutputParser):
"""Class to parse the output into a dictionary.""" """Class to parse the output into a dictionary."""
regex_pattern: str = r"{}:\s?([^.'\n']*)\.?" # : :meta private: regex_pattern: str = r"{}:\s?([^.'\n']*)\.?" # : :meta private:

@ -20,8 +20,8 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float:
https://www.nltk.org/_modules/nltk/translate/bleu_score.html https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf https://aclanthology.org/P02-1040.pdf
""" """
from nltk.translate.bleu_score import ( # type: ignore from nltk.translate.bleu_score import (
SmoothingFunction, SmoothingFunction, # type: ignore
sentence_bleu, sentence_bleu,
) )

@ -99,7 +99,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys) return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, BaseModel): class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
"""ExampleSelector that selects examples based on Max Marginal Relevance. """ExampleSelector that selects examples based on Max Marginal Relevance.
This was shown to improve performance in this paper: This was shown to improve performance in this paper:

@ -1,7 +1,7 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.prompts.base import ( from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
@ -12,7 +12,7 @@ from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(StringPromptTemplate, BaseModel): class FewShotPromptTemplate(StringPromptTemplate):
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None examples: Optional[List[dict]] = None

@ -1,17 +1,14 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.prompts.base import ( from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
)
from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
class FewShotPromptWithTemplates(StringPromptTemplate, BaseModel): class FewShotPromptWithTemplates(StringPromptTemplate):
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None examples: Optional[List[dict]] = None

@ -5,7 +5,7 @@ from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra, root_validator from pydantic import Extra, root_validator
from langchain.prompts.base import ( from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
@ -14,7 +14,7 @@ from langchain.prompts.base import (
) )
class PromptTemplate(StringPromptTemplate, BaseModel): class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM. """Schema to represent a prompt for an LLM.
Example: Example:

@ -6,9 +6,7 @@ from sqlalchemy.orm import Session
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.vectorstores.pgvector import PGVector from langchain.vectorstores.pgvector import PGVector
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
FakeEmbeddings,
)
CONNECTION_STRING = PGVector.connection_string_from_db_params( CONNECTION_STRING = PGVector.connection_string_from_db_params(
driver=os.environ.get("TEST_PGVECTOR_DRIVER", "psycopg2"), driver=os.environ.get("TEST_PGVECTOR_DRIVER", "psycopg2"),

@ -2,8 +2,6 @@
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents import AgentExecutor, AgentType, initialize_agent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import CallbackManager from langchain.callbacks.base import CallbackManager
@ -11,7 +9,7 @@ from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeListLLM(LLM, BaseModel): class FakeListLLM(LLM):
"""Fake LLM for testing that outputs elements of a list.""" """Fake LLM for testing that outputs elements of a list."""
responses: List[str] responses: List[str]

@ -2,8 +2,6 @@
from typing import Any, List, Mapping, Optional, Union from typing import Any, List, Mapping, Optional, Union
from pydantic import BaseModel
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
@ -23,7 +21,7 @@ Made in 2022."""
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}") _FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
class FakeListLLM(LLM, BaseModel): class FakeListLLM(LLM):
"""Fake LLM for testing that outputs elements of a list.""" """Fake LLM for testing that outputs elements of a list."""
responses: List[str] responses: List[str]

@ -2,7 +2,6 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -10,7 +9,7 @@ from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(BaseMemory, BaseModel): class FakeMemory(BaseMemory):
"""Fake memory class for testing purposes.""" """Fake memory class for testing purposes."""
@property @property
@ -33,7 +32,7 @@ class FakeMemory(BaseMemory, BaseModel):
pass pass
class FakeChain(Chain, BaseModel): class FakeChain(Chain):
"""Fake chain class for testing purposes.""" """Fake chain class for testing purposes."""
be_correct: bool = True be_correct: bool = True

@ -2,7 +2,6 @@
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
from pydantic import BaseModel
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.hyde.prompts import PROMPT_MAP
@ -23,7 +22,7 @@ class FakeEmbeddings(Embeddings):
return list(np.random.uniform(0, 1, 10)) return list(np.random.uniform(0, 1, 10))
class FakeLLM(BaseLLM, BaseModel): class FakeLLM(BaseLLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
n: int = 1 n: int = 1

@ -2,13 +2,11 @@
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.chains.natbot.base import NatBotChain from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM from langchain.llms.base import LLM
class FakeLLM(LLM, BaseModel): class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:

@ -2,14 +2,13 @@
from typing import Dict, List from typing import Dict, List
import pytest import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory.simple import SimpleMemory from langchain.memory.simple import SimpleMemory
class FakeChain(Chain, BaseModel): class FakeChain(Chain):
"""Fake Chain for testing purposes.""" """Fake Chain for testing purposes."""
input_variables: List[str] input_variables: List[str]

@ -1,12 +1,10 @@
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.llms.base import LLM from langchain.llms.base import LLM
class FakeLLM(LLM, BaseModel): class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
queries: Optional[Mapping] = None queries: Optional[Mapping] = None

Loading…
Cancel
Save