forked from Archives/langchain

Comparing main...harrison/c (22 commits)

Commits (SHA1):
d3a6387ab9, 3efee27e56, 7d0b1cafd7, 6fe6af7048, 6953c2e707, 3086b752a3,
03e3cd468b, 7eb33690a9, 23b8cfc123, db5c8e0c42, aae3609aa8, a3d2a2ec2a,
45d6de177e, 175a248506, b902bddb8a, 164806a844, e3edd74eab, 52490e2dcd,
7e36f28e78, 5d43246694, 36922318d3, 46b31626b5
@@ -51,7 +51,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 2,
     "id": "0728f0d9",
     "metadata": {},
     "outputs": [],
@@ -69,7 +69,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 3,
     "id": "ba4e7618",
     "metadata": {},
     "outputs": [],
@@ -87,7 +87,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": 4,
     "id": "03208e2b",
     "metadata": {},
     "outputs": [],
@@ -105,7 +105,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": 5,
     "id": "244ee75c",
     "metadata": {},
     "outputs": [
@@ -131,7 +131,7 @@
     "\u001b[0m\n",
     "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
     "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n",
-    "\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
+    "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
@@ -140,7 +140,7 @@
     "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\""
    ]
   },
-  "execution_count": 6,
+  "execution_count": 5,
   "metadata": {},
   "output_type": "execute_result"
  }
@@ -166,12 +166,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.9.0"
-  },
-  "vscode": {
-   "interpreter": {
-    "hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103"
-   }
+   "version": "3.10.9"
  }
 },
 "nbformat": 4,
@@ -4,6 +4,7 @@ from typing import Optional

 from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
 from langchain.cache import BaseCache
+from langchain.callbacks import set_default_callback_manager, set_handler
 from langchain.chains import (
     ConversationChain,
     LLMBashChain,
@@ -19,7 +20,6 @@ from langchain.chains import (
 from langchain.docstore import InMemoryDocstore, Wikipedia
 from langchain.llms import Cohere, HuggingFaceHub, OpenAI
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
-from langchain.logger import BaseLogger, StdOutLogger
 from langchain.prompts import (
     BasePromptTemplate,
     FewShotPromptTemplate,
@@ -31,9 +31,9 @@ from langchain.sql_database import SQLDatabase
 from langchain.utilities.google_search import GoogleSearchAPIWrapper
 from langchain.vectorstores import FAISS, ElasticVectorSearch

-logger: BaseLogger = StdOutLogger()
 verbose: bool = False
 llm_cache: Optional[BaseCache] = None
+set_default_callback_manager()

 __all__ = [
     "LLMChain",
@@ -65,4 +65,5 @@ __all__ = [
     "VectorDBQAWithSourcesChain",
     "QAWithSourcesChain",
     "PALChain",
+    "set_handler",
 ]
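With this change, importing langchain calls set_default_callback_manager() at import time, so a stdout handler is installed on the shared manager before any chain runs. A minimal sketch of what that implies (assuming this branch is installed):

import langchain  # import side effect: set_default_callback_manager()
from langchain.callbacks import get_callback_manager

# The global manager is the shared singleton defined later in this diff.
manager = get_callback_manager()
print(type(manager).__name__)  # SharedCallbackManager
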
@@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 from pydantic import BaseModel, root_validator

-import langchain
 from langchain.agents.tools import Tool
+from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
@@ -46,7 +46,7 @@ class Agent(BaseModel):

     def plan(
         self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
-    ) -> Union[AgentFinish, AgentAction]:
+    ) -> Union[AgentAction, AgentFinish]:
         """Given input, decide what to do.

         Args:
@@ -132,10 +132,19 @@ class Agent(BaseModel):
         pass

     @classmethod
-    def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent:
+    def from_llm_and_tools(
+        cls,
+        llm: BaseLLM,
+        tools: List[Tool],
+        callback_manager: Optional[BaseCallbackManager] = None,
+    ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
+        llm_chain = LLMChain(
+            llm=llm,
+            prompt=cls.create_prompt(tools),
+            callback_manager=callback_manager,
+        )
         return cls(llm_chain=llm_chain)

     def return_stopped_response(
@@ -194,10 +203,16 @@ class AgentExecutor(Chain, BaseModel):

     @classmethod
     def from_agent_and_tools(
-        cls, agent: Agent, tools: List[Tool], **kwargs: Any
+        cls,
+        agent: Agent,
+        tools: List[Tool],
+        callback_manager: Optional[BaseCallbackManager] = None,
+        **kwargs: Any,
     ) -> AgentExecutor:
         """Create from agent and tools."""
-        return cls(agent=agent, tools=tools, **kwargs)
+        return cls(
+            agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
+        )

     @property
     def input_keys(self) -> List[str]:
@@ -244,24 +259,31 @@ class AgentExecutor(Chain, BaseModel):
         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             if self.verbose:
-                langchain.logger.log_agent_end(output, color="green")
+                self.callback_manager.on_agent_end(output, color="green")
             final_output = output.return_values
             if self.return_intermediate_steps:
                 final_output["intermediate_steps"] = intermediate_steps
             return final_output
-        if self.verbose:
-            langchain.logger.log_agent_action(output, color="green")
         # And then we lookup the tool
         if output.tool in name_to_tool_map:
             chain = name_to_tool_map[output.tool]
+            if self.verbose:
+                self.callback_manager.on_tool_start(
+                    {"name": str(chain)[:60] + "..."}, output, color="green"
+                )
             # We then call the tool on the tool input to get an observation
             observation = chain(output.tool_input)
             color = color_mapping[output.tool]
         else:
+            if self.verbose:
+                self.callback_manager.on_tool_start(
+                    {"name": "N/A"}, output, color="green"
+                )
             observation = f"{output.tool} is not a valid tool, try another one."
             color = None
         if self.verbose:
-            langchain.logger.log_agent_observation(
+            self.callback_manager.on_tool_end(
                 observation,
                 color=color,
                 observation_prefix=self.agent.observation_prefix,
@@ -272,6 +294,8 @@ class AgentExecutor(Chain, BaseModel):
         output = self.agent.return_stopped_response(
             self.early_stopping_method, intermediate_steps, **inputs
         )
+        if self.verbose:
+            self.callback_manager.on_agent_end(output, color="green")
         final_output = output.return_values
         if self.return_intermediate_steps:
             final_output["intermediate_steps"] = intermediate_steps
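The branch logic above is the heart of the executor: plan() yields either an AgentAction (call a tool) or an AgentFinish (return values), and each arm now fires the matching callback. A condensed, standalone sketch of that dispatch (illustrative names, not the library's API):

from dataclasses import dataclass, field
from typing import Callable, Dict, Union


@dataclass
class AgentAction:
    tool: str
    tool_input: str


@dataclass
class AgentFinish:
    return_values: Dict[str, str] = field(default_factory=dict)


def dispatch(
    output: Union[AgentAction, AgentFinish],
    tools: Dict[str, Callable[[str], str]],
) -> Union[Dict[str, str], str]:
    # Finishing tool: end the run and hand back the final values.
    if isinstance(output, AgentFinish):
        return output.return_values
    # Known tool: run it on the tool input to get an observation.
    if output.tool in tools:
        return tools[output.tool](output.tool_input)
    # Unknown tool: feed an error observation back to the agent.
    return f"{output.tool} is not a valid tool, try another one."


print(dispatch(AgentAction("echo", "hi"), {"echo": str.upper}))  # HI
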
@@ -1,11 +1,12 @@
 """Load agent."""
-from typing import Any, List
+from typing import Any, List, Optional

 from langchain.agents.agent import AgentExecutor
 from langchain.agents.mrkl.base import ZeroShotAgent
 from langchain.agents.react.base import ReActDocstoreAgent
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
 from langchain.agents.tools import Tool
+from langchain.callbacks.base import BaseCallbackManager
 from langchain.llms.base import BaseLLM

 AGENT_TO_CLASS = {
@@ -19,6 +20,7 @@ def initialize_agent(
     tools: List[Tool],
     llm: BaseLLM,
     agent: str = "zero-shot-react-description",
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Load agent given tools and LLM.
@@ -28,6 +30,8 @@ def initialize_agent(
         llm: Language model to use as the agent.
         agent: The agent to use. Valid options are:
             `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
+        callback_manager: CallbackManager to use. Global callback manager is used if
+            not provided. Defaults to None.
         **kwargs: Additional keyword arguments to pass to the agent.

     Returns:
@@ -39,5 +43,12 @@ def initialize_agent(
         f"Valid types are: {AGENT_TO_CLASS.keys()}."
     )
     agent_cls = AGENT_TO_CLASS[agent]
-    agent_obj = agent_cls.from_llm_and_tools(llm, tools)
-    return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs)
+    agent_obj = agent_cls.from_llm_and_tools(
+        llm, tools, callback_manager=callback_manager
+    )
+    return AgentExecutor.from_agent_and_tools(
+        agent=agent_obj,
+        tools=tools,
+        callback_manager=callback_manager,
+        **kwargs,
+    )
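One manager can now flow through the agent, its LLMChain, and the executor. A hedged usage sketch of the new parameter (tools and llm are placeholders the caller must build, and the import path for initialize_agent is assumed):

from langchain.agents import initialize_agent  # assumed export path
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

manager = CallbackManager(handlers=[StdOutCallbackHandler()])
agent_executor = initialize_agent(
    tools,  # placeholder: List[Tool] built by the caller
    llm,    # placeholder: a BaseLLM such as an OpenAI wrapper
    agent="zero-shot-react-description",
    callback_manager=manager,  # shared by agent, LLMChain, and executor
)
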
langchain/callbacks/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+"""Callback handlers that allow listening to events in LangChain."""
+from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
+from langchain.callbacks.shared import SharedCallbackManager
+from langchain.callbacks.stdout import StdOutCallbackHandler
+
+
+def get_callback_manager() -> BaseCallbackManager:
+    """Return the shared callback manager."""
+    return SharedCallbackManager()
+
+
+def set_handler(handler: BaseCallbackHandler) -> None:
+    """Set handler."""
+    callback = get_callback_manager()
+    callback.set_handler(handler)
+
+
+def set_default_callback_manager() -> None:
+    """Set default callback manager."""
+    set_handler(StdOutCallbackHandler())
langchain/callbacks/base.py
Normal file
177
langchain/callbacks/base.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallbackHandler(BaseModel, ABC):
|
||||||
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
ignore_llm: bool = False
|
||||||
|
ignore_chain: bool = False
|
||||||
|
ignore_agent: bool = False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
"""Run on arbitrary text."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||||
|
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a handler to the callback manager."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a handler from the callback manager."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Set handler as the only handler on the callback manager."""
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackManager(BaseCallbackManager):
|
||||||
|
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
handlers: List[BaseCallbackHandler]
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
|
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||||
|
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
|
handler.on_llm_error(error)
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
|
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
|
handler.on_chain_end(outputs)
|
||||||
|
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
|
handler.on_chain_error(error)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
handler.on_tool_start(serialized, action, **kwargs)
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
handler.on_tool_end(output, **kwargs)
|
||||||
|
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
handler.on_tool_error(error)
|
||||||
|
|
||||||
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
"""Run on additional input from chains and agents."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_text(text, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
|
handler.on_agent_end(finish, **kwargs)
|
||||||
|
|
||||||
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a handler to the callback manager."""
|
||||||
|
self.handlers.append(handler)
|
||||||
|
|
||||||
|
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a handler from the callback manager."""
|
||||||
|
self.handlers.remove(handler)
|
||||||
|
|
||||||
|
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Set handler as the only handler on the callback manager."""
|
||||||
|
self.handlers = [handler]
|
114
langchain/callbacks/shared.py
Normal file
114
langchain/callbacks/shared.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
"""A shared CallbackManager."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.base import (
|
||||||
|
BaseCallbackHandler,
|
||||||
|
BaseCallbackManager,
|
||||||
|
CallbackManager,
|
||||||
|
)
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton:
|
||||||
|
"""A thread-safe singleton class that can be inherited from."""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __new__(cls) -> Any:
|
||||||
|
"""Create a new shared instance of the class."""
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
# Another thread could have created the instance
|
||||||
|
# before we acquired the lock. So check that the
|
||||||
|
# instance is still nonexistent.
|
||||||
|
if not cls._instance:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
|
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||||
|
"""A thread-safe singleton CallbackManager."""
|
||||||
|
|
||||||
|
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
|
||||||
|
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_end(response)
|
||||||
|
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_error(error)
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_end(outputs)
|
||||||
|
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_error(error)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_start(serialized, action, **kwargs)
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_end(output, **kwargs)
|
||||||
|
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_error(error)
|
||||||
|
|
||||||
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
"""Run on arbitrary text."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_text(text, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_agent_end(finish, **kwargs)
|
||||||
|
|
||||||
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a callback to the callback manager."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.add_handler(callback)
|
||||||
|
|
||||||
|
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a callback from the callback manager."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.remove_handler(callback)
|
||||||
|
|
||||||
|
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Set handler as the only handler on the callback manager."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.handlers = [handler]
|
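The double-checked locking in __new__ means every SharedCallbackManager() call returns one process-wide object, which is exactly what makes set_handler global. A sketch, assuming this branch is installed:

from langchain.callbacks.shared import SharedCallbackManager

a = SharedCallbackManager()
b = SharedCallbackManager()
assert a is b  # one shared instance; handlers registered via `a` fire for `b`
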
langchain/callbacks/stdout.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+"""Callback Handler that prints to std out."""
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.input import print_text
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class StdOutCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that prints to std out."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        print("Prompts after formatting:")
+        for prompt in prompts:
+            print_text(prompt, color="green", end="\n")
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        print("\n\033[1m> Finished chain.\033[0m")
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        color: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Print out the log in specified color."""
+        print_text(action.log, color=color)
+
+    def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        print_text(f"\n{observation_prefix}")
+        print_text(output, color=color)
+        print_text(f"\n{llm_prefix}")
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_text(
+        self,
+        text: str,
+        color: Optional[str] = None,
+        end: str = "",
+        **kwargs: Optional[str],
+    ) -> None:
+        """Run on arbitrary text."""
+        print_text(text, color=color, end=end)
+
+    def on_agent_end(
+        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+    ) -> None:
+        """Run on agent end."""
+        print_text(finish.log, color=color, end="\n")
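This handler reproduces the banners that Chain.__call__ used to print directly, which is why the notebook output earlier in this diff changes from "> Finished AgentExecutor chain." to the handler's generic "> Finished chain.". Driving it by hand shows the mapping (a sketch, assuming this branch is installed):

from langchain.callbacks.stdout import StdOutCallbackHandler

handler = StdOutCallbackHandler()
handler.on_chain_start({"name": "AgentExecutor"}, inputs={})
# prints (bold, via ANSI escapes): > Entering new AgentExecutor chain...
handler.on_chain_end(outputs={})
# prints (bold): > Finished chain.
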
langchain/callbacks/streamlit.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+"""Callback Handler that logs to streamlit."""
+from typing import Any, Dict, List, Optional
+
+import streamlit as st
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that logs to streamlit."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        st.write("Prompts after formatting:")
+        for prompt in prompts:
+            st.write(prompt)
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        st.write(f"Entering new {class_name} chain...")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        st.write("Finished chain.")
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        **kwargs: Any,
+    ) -> None:
+        """Print out the log."""
+        # st.write requires two spaces before a newline to render it
+        st.markdown(action.log.replace("\n", "  \n"))
+
+    def on_tool_end(
+        self,
+        output: str,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        st.write(f"{observation_prefix}{output}")
+        st.write(llm_prefix)
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on text."""
+        # st.write requires two spaces before a newline to render it
+        st.write(text.replace("\n", "  \n"))
+
+    def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run on agent end."""
+        # st.write requires two spaces before a newline to render it
+        st.write(finish.log.replace("\n", "  \n"))
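The Streamlit handler mirrors the stdout one but writes into an app instead of the terminal, so wiring it in is a single call (a sketch; this must run inside a Streamlit app, since st.write needs one):

from langchain.callbacks import set_handler
from langchain.callbacks.streamlit import StreamlitCallbackHandler

set_handler(StreamlitCallbackHandler())
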
@@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator
 from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.input import print_text
 from langchain.llms.base import BaseLLM
 from langchain.requests import RequestsWrapper

@@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel):
             question=question, api_docs=self.api_docs
         )
         if self.verbose:
-            print_text(api_url, color="green", end="\n")
+            self.callback_manager.on_text(api_url, color="green", end="\n")
         api_response = self.requests_wrapper.run(api_url)
         if self.verbose:
-            print_text(api_response, color="yellow", end="\n")
+            self.callback_manager.on_text(api_response, color="yellow", end="\n")
         answer = self.api_answer_chain.predict(
             question=question,
             api_docs=self.api_docs,
@@ -2,9 +2,11 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Union

-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, validator

 import langchain
+from langchain.callbacks import get_callback_manager
+from langchain.callbacks.base import BaseCallbackManager


 class Memory(BaseModel, ABC):
@@ -42,9 +44,36 @@ class Chain(BaseModel, ABC):
     """Base interface that all chains should implement."""

     memory: Optional[Memory] = None
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+    verbose: bool = Field(
+        default_factory=_get_verbosity
+    )  # Whether to print the response text
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as callback manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
+
+    @validator("verbose", pre=True, always=True)
+    def set_verbose(cls, verbose: Optional[bool]) -> bool:
+        """If verbose is None, set it.
+
+        This allows users to pass in None as verbose to access the global setting.
+        """
+        if verbose is None:
+            return _get_verbosity()
+        else:
+            return verbose

     @property
     @abstractmethod
@@ -106,12 +135,12 @@ class Chain(BaseModel, ABC):
         inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
         if self.verbose:
-            print(
-                f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
+            self.callback_manager.on_chain_start(
+                {"name": self.__class__.__name__}, inputs
             )
         outputs = self._call(inputs)
         if self.verbose:
-            print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
+            self.callback_manager.on_chain_end(outputs)
         self._validate_outputs(outputs)
        if self.memory is not None:
             self.memory.save_context(inputs, outputs)
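The pair of `pre=True, always=True` validators is what lets callers pass callback_manager=None or verbose=None and still land on the global defaults. The same pattern in a self-contained pydantic-v1 sketch (Settings and GLOBAL_MANAGER are illustrative names, not the library's):

from typing import Any, Optional

from pydantic import BaseModel, Field, validator

GLOBAL_MANAGER: Any = object()  # stand-in for get_callback_manager()


class Settings(BaseModel):
    manager: Any = Field(default_factory=lambda: GLOBAL_MANAGER)

    @validator("manager", pre=True, always=True)
    def fall_back(cls, v: Optional[Any]) -> Any:
        # pre=True runs before coercion; always=True runs even when the field
        # is omitted, so an explicit None falls through to the global default.
        return v if v is not None else GLOBAL_MANAGER


assert Settings().manager is GLOBAL_MANAGER
assert Settings(manager=None).manager is GLOBAL_MANAGER
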
@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Sequence, Union

 from pydantic import BaseModel, Extra

-import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import BaseLLM, LLMResult
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import LLMResult


 class LLMChain(Chain, BaseModel):
@@ -60,8 +60,6 @@ class LLMChain(Chain, BaseModel):
         for inputs in input_list:
             selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
             prompt = self.prompt.format(**selected_inputs)
-            if self.verbose:
-                langchain.logger.log_llm_inputs(selected_inputs, prompt)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
@@ -77,8 +75,6 @@ class LLMChain(Chain, BaseModel):
         for generation in response.generations:
             # Get the text of the top generated string.
             response_str = generation[0].text
-            if self.verbose:
-                langchain.logger.log_llm_response(response_str)
             outputs.append({self.output_key: response_str})
         return outputs
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_bash.prompt import PROMPT
-from langchain.input import print_text
 from langchain.llms.base import BaseLLM
 from langchain.utilities.bash import BashProcess

@@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel):
         llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
         bash_executor = BashProcess()
         if self.verbose:
-            print_text(inputs[self.input_key])
+            self.callback_manager.on_text(inputs[self.input_key])

         t = llm_executor.predict(question=inputs[self.input_key])
         if self.verbose:
-            print_text(t, color="green")
+            self.callback_manager.on_text(t, color="green")

         t = t.strip()
         if t.startswith("```bash"):
@@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel):
             output = bash_executor.run(command_list)

             if self.verbose:
-                print_text("\nAnswer: ")
-                print_text(output, color="yellow")
+                self.callback_manager.on_text("\nAnswer: ")
+                self.callback_manager.on_text(output, color="yellow")

         else:
             raise ValueError(f"unknown format from LLM: {t}")
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
-from langchain.input import print_text
 from langchain.llms.base import BaseLLM
 from langchain.python import PythonREPL

@@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel):
         llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
         python_executor = PythonREPL()
         if self.verbose:
-            print_text(inputs[self.input_key])
+            self.callback_manager.on_text(inputs[self.input_key])
         t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
         if self.verbose:
-            print_text(t, color="green")
+            self.callback_manager.on_text(t, color="green")
         t = t.strip()
         if t.startswith("```python"):
             code = t[9:-4]
             output = python_executor.run(code)
             if self.verbose:
-                print_text("\nAnswer: ")
-                print_text(output, color="yellow")
+                self.callback_manager.on_text("\nAnswer: ")
+                self.callback_manager.on_text(output, color="yellow")
             answer = "Answer: " + output
         elif t.startswith("Answer:"):
             answer = t
@@ -12,7 +12,6 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
 from langchain.chains.pal.math_prompt import MATH_PROMPT
-from langchain.input import print_text
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.python import PythonREPL
@@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel):
         llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
         code = llm_chain.predict(stop=[self.stop], **inputs)
         if self.verbose:
-            print_text(code, color="green", end="\n")
+            self.callback_manager.on_text(code, color="green", end="\n")
         repl = PythonREPL()
         res = repl.run(code + f"\n{self.get_answer_expr}")
         return {self.output_key: res.strip()}
@@ -26,7 +26,7 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "summaries",
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
     llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -49,7 +49,7 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -97,7 +97,7 @@ def _load_refine_chain(
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
     initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -115,7 +115,10 @@ def _load_refine_chain(


 def load_qa_with_sources_chain(
-    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
+    llm: BaseLLM,
+    chain_type: str = "stuff",
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
     """Load question answering with sources chain.
@@ -26,7 +26,7 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "context",
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
     llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -48,7 +48,7 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -94,7 +94,7 @@ def _load_refine_chain(
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
     initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -111,7 +111,10 @@ def _load_refine_chain(


 def load_qa_chain(
-    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
+    llm: BaseLLM,
+    chain_type: str = "stuff",
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
     """Load question answering chain.
@@ -5,7 +5,7 @@ from typing import Dict, List
 from pydantic import BaseModel, Extra, root_validator

 from langchain.chains.base import Chain
-from langchain.input import get_color_mapping, print_text
+from langchain.input import get_color_mapping


 class SequentialChain(Chain, BaseModel):
@@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel):
         if self.strip_outputs:
             _input = _input.strip()
         if self.verbose:
-            print_text(_input, color=color_mapping[str(i)], end="\n")
+            self.callback_manager.on_text(
+                _input, color=color_mapping[str(i)], end="\n"
+            )
         return {self.output_key: _input}
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sql_database.prompt import PROMPT
-from langchain.input import print_text
 from langchain.llms.base import BaseLLM
 from langchain.sql_database import SQLDatabase

@@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
         llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
         input_text = f"{inputs[self.input_key]} \nSQLQuery:"
         if self.verbose:
-            print_text(input_text)
+            self.callback_manager.on_text(input_text)
         llm_inputs = {
             "input": input_text,
             "dialect": self.database.dialect,
@@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel):
         }
         sql_cmd = llm_chain.predict(**llm_inputs)
         if self.verbose:
-            print_text(sql_cmd, color="green")
+            self.callback_manager.on_text(sql_cmd, color="green")
         result = self.database.run(sql_cmd)
         if self.verbose:
-            print_text("\nSQLResult: ")
-            print_text(result, color="yellow")
-            print_text("\nAnswer:")
+            self.callback_manager.on_text("\nSQLResult: ")
+            self.callback_manager.on_text(result, color="yellow")
+            self.callback_manager.on_text("\nAnswer:")
         input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
         llm_inputs["input"] = input_text
         final_result = llm_chain.predict(**llm_inputs)
         if self.verbose:
-            print_text(final_result, color="green")
+            self.callback_manager.on_text(final_result, color="green")
         return {self.output_key: final_result}
@@ -22,7 +22,7 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "text",
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
     llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -44,7 +44,7 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@@ -90,7 +90,7 @@ def _load_refine_chain(
     document_variable_name: str = "text",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:

@@ -108,7 +108,10 @@ def _load_refine_chain(


 def load_summarize_chain(
-    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
+    llm: BaseLLM,
+    chain_type: str = "stuff",
+    verbose: Optional[bool] = None,
+    **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
     """Load summarizing chain.
@@ -2,34 +2,55 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, Field, validator
 
 import langchain
-from langchain.schema import Generation
+from langchain.callbacks import get_callback_manager
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.schema import Generation, LLMResult
 
 
-class LLMResult(NamedTuple):
-    """Class that contains all relevant information for an LLM Result."""
-
-    generations: List[List[Generation]]
-    """List of the things generated. This is List[List[]] because
-    each input could have multiple generations."""
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
+def _get_verbosity() -> bool:
+    return langchain.verbose
 
 
 class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     cache: Optional[bool] = None
+    verbose: bool = Field(default_factory=_get_verbosity)
+    """Whether to print out response text."""
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
 
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as callback manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
+
+    @validator("verbose", pre=True, always=True)
+    def set_verbose(cls, verbose: Optional[bool]) -> bool:
+        """If verbose is None, set it.
+
+        This allows users to pass in None as verbose to access the global setting.
+        """
+        if verbose is None:
+            return _get_verbosity()
+        else:
+            return verbose
 
     @abstractmethod
     def _generate(
@@ -48,7 +69,14 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            return self._generate(prompts, stop=stop)
+            if self.verbose:
+                self.callback_manager.on_llm_start(
+                    {"name": self.__class__.__name__}, prompts
+                )
+            output = self._generate(prompts, stop=stop)
+            if self.verbose:
+                self.callback_manager.on_llm_end(output)
+            return output
         params = self._llm_dict()
         params["stop"] = stop
         llm_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -62,7 +90,11 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
+        self.callback_manager.on_llm_start(
+            {"name": self.__class__.__name__}, missing_prompts
+        )
         new_results = self._generate(missing_prompts, stop=stop)
+        self.callback_manager.on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
             existing_prompts[i] = result
             prompt = prompts[i]
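
Note: the hunks above all lean on one pydantic (v1) idiom: compute a field default from a module-level flag via Field(default_factory=...), then let a pre=True, always=True validator map an explicit None back to the global setting. Below is a minimal, self-contained sketch of that idiom; GLOBAL_VERBOSE stands in for langchain.verbose, and Component is a hypothetical class, not library code.

    # Sketch only, assuming pydantic v1 (the version this branch targets).
    from typing import Optional

    from pydantic import BaseModel, Field, validator

    GLOBAL_VERBOSE = True  # stand-in for the module-level `langchain.verbose` flag


    def _get_verbosity() -> bool:
        """Read the global flag at instantiation time."""
        return GLOBAL_VERBOSE


    class Component(BaseModel):
        # default_factory defers the lookup to instantiation time, so flipping
        # the global later only affects instances created after the flip.
        verbose: bool = Field(default_factory=_get_verbosity)

        @validator("verbose", pre=True, always=True)
        def set_verbose(cls, verbose: Optional[bool]) -> bool:
            # An explicit True/False wins; None falls back to the global setting.
            return _get_verbosity() if verbose is None else verbose


    assert Component().verbose is True
    assert Component(verbose=None).verbose is True
    assert Component(verbose=False).verbose is False

A plain default (verbose: bool = GLOBAL_VERBOSE) would be evaluated once at class-definition time and never see later changes to the flag, which is why default_factory is used here.
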
@@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
 
@@ -1,71 +0,0 @@
-"""BETA: everything in here is highly experimental, do not rely on."""
-from typing import Any, Optional
-
-from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish
-
-
-class BaseLogger:
-    """Base logging interface."""
-
-    def log_agent_start(self, text: str, **kwargs: Any) -> None:
-        """Log the start of an agent interaction."""
-        pass
-
-    def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
-        """Log the end of an agent interaction."""
-        pass
-
-    def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
-        """Log agent action decision."""
-        pass
-
-    def log_agent_observation(self, observation: str, **kwargs: Any) -> None:
-        """Log agent observation."""
-        pass
-
-    def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
-        """Log LLM inputs."""
-        pass
-
-    def log_llm_response(self, output: str, **kwargs: Any) -> None:
-        """Log LLM response."""
-        pass
-
-
-class StdOutLogger(BaseLogger):
-    """Interface for printing things to stdout."""
-
-    def log_agent_start(self, text: str, **kwargs: Any) -> None:
-        """Print the text to start the agent."""
-        print_text(text)
-
-    def log_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Print the log of the action in a certain color."""
-        print_text(action.log, color=color)
-
-    def log_agent_observation(
-        self,
-        observation: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Print the observation in a special color."""
-        print_text(f"\n{observation_prefix}")
-        print_text(observation, color=color)
-        print_text(f"\n{llm_prefix}")
-
-    def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
-        """Print the prompt in green."""
-        print("Prompt after formatting:")
-        print_text(prompt, color="green", end="\n")
-
-    def log_agent_end(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Log the end of an agent interaction."""
-        print_text(finish.log, color=color)
@@ -1,6 +1,6 @@
 """Common schema objects."""
 
-from typing import NamedTuple
+from typing import List, NamedTuple, Optional
 
 
 class AgentAction(NamedTuple):
@@ -24,3 +24,13 @@ class Generation(NamedTuple):
     text: str
     """Generated text output."""
     # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
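
Note: a small illustration of the shape the relocated LLMResult carries (this assumes a checkout of this branch, since langchain.schema only gains LLMResult with the hunk above): the outer list has one entry per input prompt, the inner list one entry per completion sampled for that prompt.

    from langchain.schema import Generation, LLMResult

    # Two prompts in, two sampled completions per prompt out.
    result = LLMResult(
        generations=[
            [Generation(text="Paris"), Generation(text="Paris, France")],
            [Generation(text="Berlin"), Generation(text="Berlin, Germany")],
        ],
        llm_output={"token_usage": {}},  # optional, provider-specific extras
    )
    assert len(result.generations) == 2  # one entry per prompt
    assert len(result.generations[0]) == 2  # one entry per sampled completion
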
@@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
 
 from pydantic import BaseModel
 
-from langchain.agents import Tool, initialize_agent
+from langchain.agents import AgentExecutor, Tool, initialize_agent
+from langchain.callbacks.base import CallbackManager
 from langchain.llms.base import LLM
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
 class FakeListLLM(LLM, BaseModel):
@@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
         return "fake_list"
 
 
-def test_agent_bad_action() -> None:
-    """Test react chain when bad action given."""
+def _get_agent(**kwargs: Any) -> AgentExecutor:
+    """Get agent for testing."""
     bad_action_name = "BadAction"
     responses = [
         f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
@@ -44,30 +46,122 @@ def test_agent_bad_action() -> None:
         Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
     ]
     agent = initialize_agent(
-        tools, fake_llm, agent="zero-shot-react-description", verbose=True
+        tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
     )
+    return agent
+
+
+def test_agent_bad_action() -> None:
+    """Test react chain when bad action given."""
+    agent = _get_agent()
     output = agent.run("when was langchain made")
     assert output == "curses foiled again"
 
 
 def test_agent_stopped_early() -> None:
     """Test react chain when bad action given."""
-    bad_action_name = "BadAction"
+    agent = _get_agent(max_iterations=0)
+    output = agent.run("when was langchain made")
+    assert output == "Agent stopped due to max iterations."
+
+
+def test_agent_with_callbacks_global() -> None:
+    """Test react chain with callbacks by setting verbose globally."""
+    import langchain
+
+    langchain.verbose = True
+    handler = FakeCallbackHandler()
+    manager = CallbackManager(handlers=[handler])
+    tool = "Search"
     responses = [
-        f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
         "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
     ]
-    fake_llm = FakeListLLM(responses=responses)
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
     tools = [
         Tool("Search", lambda x: x, "Useful for searching"),
-        Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
     ]
     agent = initialize_agent(
         tools,
         fake_llm,
         agent="zero-shot-react-description",
         verbose=True,
-        max_iterations=0,
+        callback_manager=manager,
     )
+
     output = agent.run("when was langchain made")
-    assert output == "Agent stopped due to max iterations."
+    assert output == "curses foiled again"
+
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
+    assert handler.starts == 6
+    # 1 extra agent end
+    assert handler.ends == 7
+    assert handler.errors == 0
+
+
+def test_agent_with_callbacks_local() -> None:
+    """Test react chain with callbacks by setting verbose locally."""
+    import langchain
+
+    langchain.verbose = False
+    handler = FakeCallbackHandler()
+    manager = CallbackManager(handlers=[handler])
+    tool = "Search"
+    responses = [
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
+        "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
+    ]
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
+    tools = [
+        Tool("Search", lambda x: x, "Useful for searching"),
+    ]
+    agent = initialize_agent(
+        tools,
+        fake_llm,
+        agent="zero-shot-react-description",
+        verbose=True,
+        callback_manager=manager,
+    )
+
+    agent.agent.llm_chain.verbose = True
+
+    output = agent.run("when was langchain made")
+    assert output == "curses foiled again"
+
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
+    assert handler.starts == 6
+    # 1 extra agent end
+    assert handler.ends == 7
+    assert handler.errors == 0
+
+
+def test_agent_with_callbacks_not_verbose() -> None:
+    """Test react chain with callbacks but not verbose."""
+    import langchain
+
+    langchain.verbose = False
+    handler = FakeCallbackHandler()
+    manager = CallbackManager(handlers=[handler])
+    tool = "Search"
+    responses = [
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
+        "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
+    ]
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
+    tools = [
+        Tool("Search", lambda x: x, "Useful for searching"),
+    ]
+    agent = initialize_agent(
+        tools,
+        fake_llm,
+        agent="zero-shot-react-description",
+        callback_manager=manager,
+    )
+
+    output = agent.run("when was langchain made")
+    assert output == "curses foiled again"
+
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
tests/unit_tests/callbacks/__init__.py (new file)
@@ -0,0 +1 @@
+"""Tests for correct functioning of callbacks."""
tests/unit_tests/callbacks/fake_callback_handler.py (new file)
@@ -0,0 +1,66 @@
+"""A fake callback handler for testing purposes."""
+from typing import Any, Dict, List
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class FakeCallbackHandler(BaseCallbackHandler):
+    """Fake callback handler for testing."""
+
+    starts: int = 0
+    ends: int = 0
+    errors: int = 0
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        self.starts += 1
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+        self.ends += 1
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+        self.errors += 1
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        self.starts += 1
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+        self.ends += 1
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+        self.errors += 1
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+        self.starts += 1
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+        self.ends += 1
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+        self.errors += 1
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run when agent is ending."""
+        self.ends += 1
+
+    def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run when agent ends running."""
+        self.ends += 1
tests/unit_tests/callbacks/test_callback_manager.py (new file)
@@ -0,0 +1,47 @@
+"""Test CallbackManager."""
+
+from langchain.callbacks.base import BaseCallbackManager, CallbackManager
+from langchain.callbacks.shared import SharedCallbackManager
+from langchain.schema import AgentAction, LLMResult
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+
+def _test_callback_manager(
+    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
+) -> None:
+    """Test the CallbackManager."""
+    manager.on_llm_start({}, [])
+    manager.on_llm_end(LLMResult(generations=[]))
+    manager.on_llm_error(Exception())
+    manager.on_chain_start({"name": "foo"}, {})
+    manager.on_chain_end({})
+    manager.on_chain_error(Exception())
+    manager.on_tool_start({}, AgentAction("", "", ""))
+    manager.on_tool_end("")
+    manager.on_tool_error(Exception())
+    for handler in handlers:
+        assert handler.starts == 3
+        assert handler.ends == 3
+        assert handler.errors == 3
+
+
+def test_callback_manager() -> None:
+    """Test the CallbackManager."""
+    handler1 = FakeCallbackHandler()
+    handler2 = FakeCallbackHandler()
+    manager = CallbackManager(handlers=[handler1, handler2])
+    _test_callback_manager(manager, handler1, handler2)
+
+
+def test_shared_callback_manager() -> None:
+    """Test the SharedCallbackManager."""
+    manager1 = SharedCallbackManager()
+    manager2 = SharedCallbackManager()
+
+    assert manager1 is manager2
+
+    handler1 = FakeCallbackHandler()
+    handler2 = FakeCallbackHandler()
+    manager1.add_handler(handler1)
+    manager2.add_handler(handler2)
+    _test_callback_manager(manager1, handler1, handler2)
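
Note: the assertions above depend on the manager fanning each event out to every registered handler, so three events of each kind leave both handlers' counters at three. A rough sketch of that dispatch idea follows; TinyFanOutManager is illustrative only, not the library's actual implementation.

    from typing import Any, List


    class TinyFanOutManager:
        """Hypothetical stand-in showing CallbackManager-style fan-out."""

        def __init__(self, handlers: List[Any]) -> None:
            self.handlers = handlers

        def _fan_out(self, event: str, *args: Any, **kwargs: Any) -> None:
            # Forward the named event to every registered handler in turn.
            for handler in self.handlers:
                getattr(handler, event)(*args, **kwargs)

        def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
            self._fan_out("on_llm_start", *args, **kwargs)

        def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
            self._fan_out("on_llm_end", *args, **kwargs)

The SharedCallbackManager test works the same way; it additionally checks that the shared manager is a singleton, so handlers added through either reference see the same events.
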
@@ -4,7 +4,9 @@ from typing import Any, Dict, List
 import pytest
 from pydantic import BaseModel
 
+from langchain.callbacks.base import CallbackManager
 from langchain.chains.base import Chain, Memory
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
 class FakeMemory(Memory, BaseModel):
@@ -133,3 +135,31 @@ def test_run_arg_with_memory() -> None:
     """Test run method works when arg is passed."""
     chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
     chain.run("bar")
+
+
+def test_run_with_callback() -> None:
+    """Test run method works when callback manager is passed."""
+    handler = FakeCallbackHandler()
+    chain = FakeChain(
+        callback_manager=CallbackManager(handlers=[handler]), verbose=True
+    )
+    output = chain.run("bar")
+    assert output == "baz"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+
+
+def test_run_with_callback_not_verbose() -> None:
+    """Test run method works when callback manager is passed and not verbose."""
+    import langchain
+
+    langchain.verbose = False
+
+    handler = FakeCallbackHandler()
+    chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
+    output = chain.run("bar")
+    assert output == "baz"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
tests/unit_tests/llms/test_callbacks.py (new file)
@@ -0,0 +1,30 @@
+"""Test LLM callbacks."""
+from langchain.callbacks.base import CallbackManager
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_llm_with_callbacks() -> None:
+    """Test LLM callbacks."""
+    handler = FakeCallbackHandler()
+    llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
+    output = llm("foo")
+    assert output == "foo"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+
+
+def test_llm_with_callbacks_not_verbose() -> None:
+    """Test LLM callbacks but not verbose."""
+    import langchain
+
+    langchain.verbose = False
+
+    handler = FakeCallbackHandler()
+    llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
+    output = llm("foo")
+    assert output == "foo"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
@@ -7,8 +7,8 @@ from pydantic import BaseModel
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
 from langchain.embeddings.hyde.prompts import PROMPT_MAP
-from langchain.llms.base import BaseLLM, LLMResult
-from langchain.schema import Generation
+from langchain.llms.base import BaseLLM
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
 
 class FakeEmbeddings(Embeddings):