forked from Archives/langchain
Compare commits
22 Commits
main
...
harrison/c
Author | SHA1 | Date | |
---|---|---|---|
|
d3a6387ab9 | ||
|
3efee27e56 | ||
|
7d0b1cafd7 | ||
|
6fe6af7048 | ||
|
6953c2e707 | ||
|
3086b752a3 | ||
|
03e3cd468b | ||
|
7eb33690a9 | ||
|
23b8cfc123 | ||
|
db5c8e0c42 | ||
|
aae3609aa8 | ||
|
a3d2a2ec2a | ||
|
45d6de177e | ||
|
175a248506 | ||
|
b902bddb8a | ||
|
164806a844 | ||
|
e3edd74eab | ||
|
52490e2dcd | ||
|
7e36f28e78 | ||
|
5d43246694 | ||
|
36922318d3 | ||
|
46b31626b5 |
@ -51,7 +51,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "0728f0d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -69,7 +69,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "ba4e7618",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -87,7 +87,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "03208e2b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -105,7 +105,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "244ee75c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -131,7 +131,7 @@
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n",
|
||||
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -140,7 +140,7 @@
|
||||
"\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -166,12 +166,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.0"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103"
|
||||
}
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.cache import BaseCache
|
||||
from langchain.callbacks import set_default_callback_manager, set_handler
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMBashChain,
|
||||
@ -19,7 +20,6 @@ from langchain.chains import (
|
||||
from langchain.docstore import InMemoryDocstore, Wikipedia
|
||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.logger import BaseLogger, StdOutLogger
|
||||
from langchain.prompts import (
|
||||
BasePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
@ -31,9 +31,9 @@ from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
||||
from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
|
||||
logger: BaseLogger = StdOutLogger()
|
||||
verbose: bool = False
|
||||
llm_cache: Optional[BaseCache] = None
|
||||
set_default_callback_manager()
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
@ -65,4 +65,5 @@ __all__ = [
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"QAWithSourcesChain",
|
||||
"PALChain",
|
||||
"set_handler",
|
||||
]
|
||||
|
@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
import langchain
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
@ -46,7 +46,7 @@ class Agent(BaseModel):
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentFinish, AgentAction]:
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
@ -132,10 +132,19 @@ class Agent(BaseModel):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent:
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
tools: List[Tool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=cls.create_prompt(tools),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
return cls(llm_chain=llm_chain)
|
||||
|
||||
def return_stopped_response(
|
||||
@ -194,10 +203,16 @@ class AgentExecutor(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_agent_and_tools(
|
||||
cls, agent: Agent, tools: List[Tool], **kwargs: Any
|
||||
cls,
|
||||
agent: Agent,
|
||||
tools: List[Tool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Create from agent and tools."""
|
||||
return cls(agent=agent, tools=tools, **kwargs)
|
||||
return cls(
|
||||
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@ -244,24 +259,31 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_end(output, color="green")
|
||||
self.callback_manager.on_agent_end(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_action(output, color="green")
|
||||
|
||||
# And then we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
chain = name_to_tool_map[output.tool]
|
||||
if self.verbose:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": str(chain)[:60] + "..."}, output, color="green"
|
||||
)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = chain(output.tool_input)
|
||||
color = color_mapping[output.tool]
|
||||
else:
|
||||
if self.verbose:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": "N/A"}, output, color="green"
|
||||
)
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
color = None
|
||||
if self.verbose:
|
||||
langchain.logger.log_agent_observation(
|
||||
self.callback_manager.on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
@ -272,6 +294,8 @@ class AgentExecutor(Chain, BaseModel):
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
if self.verbose:
|
||||
self.callback_manager.on_agent_end(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Load agent."""
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
AGENT_TO_CLASS = {
|
||||
@ -19,6 +20,7 @@ def initialize_agent(
|
||||
tools: List[Tool],
|
||||
llm: BaseLLM,
|
||||
agent: str = "zero-shot-react-description",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Load agent given tools and LLM.
|
||||
@ -28,6 +30,8 @@ def initialize_agent(
|
||||
llm: Language model to use as the agent.
|
||||
agent: The agent to use. Valid options are:
|
||||
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
**kwargs: Additional key word arguments to pass to the agent.
|
||||
|
||||
Returns:
|
||||
@ -39,5 +43,12 @@ def initialize_agent(
|
||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs)
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
llm, tools, callback_manager=callback_manager
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
20
langchain/callbacks/__init__.py
Normal file
20
langchain/callbacks/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Callback handlers that allow listening to events in LangChain."""
|
||||
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
"""Return the shared callback manager."""
|
||||
return SharedCallbackManager()
|
||||
|
||||
|
||||
def set_handler(handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler."""
|
||||
callback = get_callback_manager()
|
||||
callback.set_handler(handler)
|
||||
|
||||
|
||||
def set_default_callback_manager() -> None:
|
||||
"""Set default callback manager."""
|
||||
set_handler(StdOutCallbackHandler())
|
177
langchain/callbacks/base.py
Normal file
177
langchain/callbacks/base.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class BaseCallbackHandler(BaseModel, ABC):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
ignore_llm: bool = False
|
||||
ignore_chain: bool = False
|
||||
ignore_agent: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@abstractmethod
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
|
||||
@abstractmethod
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
|
||||
@abstractmethod
|
||||
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
handlers: List[BaseCallbackHandler]
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
handler.on_llm_end(response)
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
handler.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
handler.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
handler.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
handler.on_tool_start(serialized, action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
handler.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
handler.on_tool_error(error)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
for handler in self.handlers:
|
||||
handler.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
handler.on_agent_end(finish, **kwargs)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.handlers = [handler]
|
114
langchain/callbacks/shared.py
Normal file
114
langchain/callbacks/shared.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""A shared CallbackManager."""
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
"""A thread-safe singleton class that can be inherited from."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Any:
|
||||
"""Create a new shared instance of the class."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
# Another thread could have created the instance
|
||||
# before we acquired the lock. So check that the
|
||||
# instance is still nonexistent.
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
"""A thread-safe singleton CallbackManager."""
|
||||
|
||||
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_end(response)
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(serialized, action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_error(error)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_end(finish, **kwargs)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.add_handler(callback)
|
||||
|
||||
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Remove a callback from the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.remove_handler(callback)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.handlers = [handler]
|
84
langchain/callbacks/stdout.py
Normal file
84
langchain/callbacks/stdout.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print("Prompts after formatting:")
|
||||
for prompt in prompts:
|
||||
print_text(prompt, color="green", end="\n")
|
||||
|
||||
def on_llm_end(self, response: LLMResult) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
color: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
print_text(action.log, color=color)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run when agent ends."""
|
||||
print_text(text, color=color, end=end)
|
||||
|
||||
def on_agent_end(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
print_text(finish.log, color=color, end="\n")
|
77
langchain/callbacks/streamlit.py
Normal file
77
langchain/callbacks/streamlit.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Callback Handler that logs to streamlit."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs to streamlit."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
st.write("Prompts after formatting:")
|
||||
for prompt in prompts:
|
||||
st.write(prompt)
|
||||
|
||||
def on_llm_end(self, response: LLMResult) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
st.write(f"Entering new {class_name} chain...")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
st.write("Finished chain.")
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
action: AgentAction,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print out the log in specified color."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.markdown(action.log.replace("\n", " \n"))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
st.write(f"{observation_prefix}{output}")
|
||||
st.write(llm_prefix)
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on text."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.write(text.replace("\n", " \n"))
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.write(finish.log.replace("\n", " \n"))
|
@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.requests import RequestsWrapper
|
||||
|
||||
@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel):
|
||||
question=question, api_docs=self.api_docs
|
||||
)
|
||||
if self.verbose:
|
||||
print_text(api_url, color="green", end="\n")
|
||||
self.callback_manager.on_text(api_url, color="green", end="\n")
|
||||
api_response = self.requests_wrapper.run(api_url)
|
||||
if self.verbose:
|
||||
print_text(api_response, color="yellow", end="\n")
|
||||
self.callback_manager.on_text(api_response, color="yellow", end="\n")
|
||||
answer = self.api_answer_chain.predict(
|
||||
question=question,
|
||||
api_docs=self.api_docs,
|
||||
|
@ -2,9 +2,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
|
||||
|
||||
class Memory(BaseModel, ABC):
|
||||
@ -42,9 +44,36 @@ class Chain(BaseModel, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[Memory] = None
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -106,12 +135,12 @@ class Chain(BaseModel, ABC):
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
|
||||
self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__}, inputs
|
||||
)
|
||||
outputs = self._call(inputs)
|
||||
if self.verbose:
|
||||
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
|
||||
self.callback_manager.on_chain_end(outputs)
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
|
@ -3,10 +3,10 @@ from typing import Any, Dict, List, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
import langchain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.base import BaseLLM, LLMResult
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@ -60,8 +60,6 @@ class LLMChain(Chain, BaseModel):
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format(**selected_inputs)
|
||||
if self.verbose:
|
||||
langchain.logger.log_llm_inputs(selected_inputs, prompt)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
@ -77,8 +75,6 @@ class LLMChain(Chain, BaseModel):
|
||||
for generation in response.generations:
|
||||
# Get the text of the top generated string.
|
||||
response_str = generation[0].text
|
||||
if self.verbose:
|
||||
langchain.logger.log_llm_response(response_str)
|
||||
outputs.append({self.output_key: response_str})
|
||||
return outputs
|
||||
|
||||
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel):
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
bash_executor = BashProcess()
|
||||
if self.verbose:
|
||||
print_text(inputs[self.input_key])
|
||||
self.callback_manager.on_text(inputs[self.input_key])
|
||||
|
||||
t = llm_executor.predict(question=inputs[self.input_key])
|
||||
if self.verbose:
|
||||
print_text(t, color="green")
|
||||
self.callback_manager.on_text(t, color="green")
|
||||
|
||||
t = t.strip()
|
||||
if t.startswith("```bash"):
|
||||
@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel):
|
||||
output = bash_executor.run(command_list)
|
||||
|
||||
if self.verbose:
|
||||
print_text("\nAnswer: ")
|
||||
print_text(output, color="yellow")
|
||||
self.callback_manager.on_text("\nAnswer: ")
|
||||
self.callback_manager.on_text(output, color="yellow")
|
||||
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {t}")
|
||||
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel):
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
python_executor = PythonREPL()
|
||||
if self.verbose:
|
||||
print_text(inputs[self.input_key])
|
||||
self.callback_manager.on_text(inputs[self.input_key])
|
||||
t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
|
||||
if self.verbose:
|
||||
print_text(t, color="green")
|
||||
self.callback_manager.on_text(t, color="green")
|
||||
t = t.strip()
|
||||
if t.startswith("```python"):
|
||||
code = t[9:-4]
|
||||
output = python_executor.run(code)
|
||||
if self.verbose:
|
||||
print_text("\nAnswer: ")
|
||||
print_text(output, color="yellow")
|
||||
self.callback_manager.on_text("\nAnswer: ")
|
||||
self.callback_manager.on_text(output, color="yellow")
|
||||
answer = "Answer: " + output
|
||||
elif t.startswith("Answer:"):
|
||||
answer = t
|
||||
|
@ -12,7 +12,6 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.python import PythonREPL
|
||||
@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel):
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
code = llm_chain.predict(stop=[self.stop], **inputs)
|
||||
if self.verbose:
|
||||
print_text(code, color="green", end="\n")
|
||||
self.callback_manager.on_text(code, color="green", end="\n")
|
||||
repl = PythonREPL()
|
||||
res = repl.run(code + f"\n{self.get_answer_expr}")
|
||||
return {self.output_key: res.strip()}
|
||||
|
@ -26,7 +26,7 @@ def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "summaries",
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
@ -49,7 +49,7 @@ def _load_map_reduce_chain(
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
@ -97,7 +97,7 @@ def _load_refine_chain(
|
||||
document_variable_name: str = "context_str",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
@ -115,7 +115,10 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_with_sources_chain(
|
||||
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
|
||||
llm: BaseLLM,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering with sources chain.
|
||||
|
||||
|
@ -26,7 +26,7 @@ def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "context",
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
@ -48,7 +48,7 @@ def _load_map_reduce_chain(
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
@ -94,7 +94,7 @@ def _load_refine_chain(
|
||||
document_variable_name: str = "context_str",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
@ -111,7 +111,10 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_chain(
|
||||
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
|
||||
llm: BaseLLM,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering chain.
|
||||
|
||||
|
@ -5,7 +5,7 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
from langchain.input import get_color_mapping
|
||||
|
||||
|
||||
class SequentialChain(Chain, BaseModel):
|
||||
@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel):
|
||||
if self.strip_outputs:
|
||||
_input = _input.strip()
|
||||
if self.verbose:
|
||||
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||
self.callback_manager.on_text(
|
||||
_input, color=color_mapping[str(i)], end="\n"
|
||||
)
|
||||
return {self.output_key: _input}
|
||||
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
|
||||
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
|
||||
if self.verbose:
|
||||
print_text(input_text)
|
||||
self.callback_manager.on_text(input_text)
|
||||
llm_inputs = {
|
||||
"input": input_text,
|
||||
"dialect": self.database.dialect,
|
||||
@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
}
|
||||
sql_cmd = llm_chain.predict(**llm_inputs)
|
||||
if self.verbose:
|
||||
print_text(sql_cmd, color="green")
|
||||
self.callback_manager.on_text(sql_cmd, color="green")
|
||||
result = self.database.run(sql_cmd)
|
||||
if self.verbose:
|
||||
print_text("\nSQLResult: ")
|
||||
print_text(result, color="yellow")
|
||||
print_text("\nAnswer:")
|
||||
self.callback_manager.on_text("\nSQLResult: ")
|
||||
self.callback_manager.on_text(result, color="yellow")
|
||||
self.callback_manager.on_text("\nAnswer:")
|
||||
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
||||
llm_inputs["input"] = input_text
|
||||
final_result = llm_chain.predict(**llm_inputs)
|
||||
if self.verbose:
|
||||
print_text(final_result, color="green")
|
||||
self.callback_manager.on_text(final_result, color="green")
|
||||
return {self.output_key: final_result}
|
||||
|
@ -22,7 +22,7 @@ def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
@ -44,7 +44,7 @@ def _load_map_reduce_chain(
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
|
||||
@ -90,7 +90,7 @@ def _load_refine_chain(
|
||||
document_variable_name: str = "text",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
verbose: bool = False,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
|
||||
@ -108,7 +108,10 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_summarize_chain(
|
||||
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
|
||||
llm: BaseLLM,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load summarizing chain.
|
||||
|
||||
|
@ -2,34 +2,55 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
|
||||
import langchain
|
||||
from langchain.schema import Generation
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
class LLMResult(NamedTuple):
|
||||
"""Class that contains all relevant information for an LLM Result."""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
"""List of the things generated. This is List[List[]] because
|
||||
each input could have multiple generations."""
|
||||
llm_output: Optional[dict] = None
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
@ -48,7 +69,14 @@ class BaseLLM(BaseModel, ABC):
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
return self._generate(prompts, stop=stop)
|
||||
if self.verbose:
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
output = self._generate(prompts, stop=stop)
|
||||
if self.verbose:
|
||||
self.callback_manager.on_llm_end(output)
|
||||
return output
|
||||
params = self._llm_dict()
|
||||
params["stop"] = stop
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
@ -62,7 +90,11 @@ class BaseLLM(BaseModel, ABC):
|
||||
else:
|
||||
missing_prompts.append(prompt)
|
||||
missing_prompt_idxs.append(i)
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, missing_prompts
|
||||
)
|
||||
new_results = self._generate(missing_prompts, stop=stop)
|
||||
self.callback_manager.on_llm_end(new_results)
|
||||
for i, result in enumerate(new_results.generations):
|
||||
existing_prompts[i] = result
|
||||
prompt = prompts[i]
|
||||
|
@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.llms.base import BaseLLM, LLMResult
|
||||
from langchain.schema import Generation
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
@ -1,71 +0,0 @@
|
||||
"""BETA: everything in here is highly experimental, do not rely on."""
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
|
||||
|
||||
class BaseLogger:
|
||||
"""Base logging interface."""
|
||||
|
||||
def log_agent_start(self, text: str, **kwargs: Any) -> None:
|
||||
"""Log the start of an agent interaction."""
|
||||
pass
|
||||
|
||||
def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Log the end of an agent interaction."""
|
||||
pass
|
||||
|
||||
def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Log agent action decision."""
|
||||
pass
|
||||
|
||||
def log_agent_observation(self, observation: str, **kwargs: Any) -> None:
|
||||
"""Log agent observation."""
|
||||
pass
|
||||
|
||||
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
|
||||
"""Log LLM inputs."""
|
||||
pass
|
||||
|
||||
def log_llm_response(self, output: str, **kwargs: Any) -> None:
|
||||
"""Log LLM response."""
|
||||
pass
|
||||
|
||||
|
||||
class StdOutLogger(BaseLogger):
|
||||
"""Interface for printing things to stdout."""
|
||||
|
||||
def log_agent_start(self, text: str, **kwargs: Any) -> None:
|
||||
"""Print the text to start the agent."""
|
||||
print_text(text)
|
||||
|
||||
def log_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Print the log of the action in a certain color."""
|
||||
print_text(action.log, color=color)
|
||||
|
||||
def log_agent_observation(
|
||||
self,
|
||||
observation: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Print the observation in a special color."""
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(observation, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
|
||||
"""Print the prompt in green."""
|
||||
print("Prompt after formatting:")
|
||||
print_text(prompt, color="green", end="\n")
|
||||
|
||||
def log_agent_end(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Log the end of an agent interaction."""
|
||||
print_text(finish.log, color=color)
|
@ -1,6 +1,6 @@
|
||||
"""Common schema objects."""
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
|
||||
class AgentAction(NamedTuple):
|
||||
@ -24,3 +24,13 @@ class Generation(NamedTuple):
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
# TODO: add log probs
|
||||
|
||||
|
||||
class LLMResult(NamedTuple):
|
||||
"""Class that contains all relevant information for an LLM Result."""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
"""List of the things generated. This is List[List[]] because
|
||||
each input could have multiple generations."""
|
||||
llm_output: Optional[dict] = None
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
|
@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import Tool, initialize_agent
|
||||
from langchain.agents import AgentExecutor, Tool, initialize_agent
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.llms.base import LLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeListLLM(LLM, BaseModel):
|
||||
@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
|
||||
return "fake_list"
|
||||
|
||||
|
||||
def test_agent_bad_action() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
def _get_agent(**kwargs: Any) -> AgentExecutor:
|
||||
"""Get agent for testing."""
|
||||
bad_action_name = "BadAction"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
@ -44,30 +46,122 @@ def test_agent_bad_action() -> None:
|
||||
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent="zero-shot-react-description", verbose=True
|
||||
tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def test_agent_bad_action() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
agent = _get_agent()
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
|
||||
def test_agent_stopped_early() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
bad_action_name = "BadAction"
|
||||
agent = _get_agent(max_iterations=0)
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to max iterations."
|
||||
|
||||
|
||||
def test_agent_with_callbacks_global() -> None:
|
||||
"""Test react chain with callbacks by setting verbose globally."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = True
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
verbose=True,
|
||||
max_iterations=0,
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to max iterations."
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 6
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_agent_with_callbacks_local() -> None:
|
||||
"""Test react chain with callbacks by setting verbose locally."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
verbose=True,
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
agent.agent.llm_chain.verbose = True
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 6
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_agent_with_callbacks_not_verbose() -> None:
|
||||
"""Test react chain with callbacks but not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
1
tests/unit_tests/callbacks/__init__.py
Normal file
1
tests/unit_tests/callbacks/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for correct functioning of callbacks."""
|
66
tests/unit_tests/callbacks/fake_callback_handler.py
Normal file
66
tests/unit_tests/callbacks/fake_callback_handler.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
errors: int = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.ends += 1
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.ends += 1
|
47
tests/unit_tests/callbacks/test_callback_manager.py
Normal file
47
tests/unit_tests/callbacks/test_callback_manager.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Test CallbackManager."""
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
for handler in handlers:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 3
|
||||
assert handler.errors == 3
|
||||
|
||||
|
||||
def test_callback_manager() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_shared_callback_manager() -> None:
|
||||
"""Test the SharedCallbackManager."""
|
||||
manager1 = SharedCallbackManager()
|
||||
manager2 = SharedCallbackManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
_test_callback_manager(manager1, handler1, handler2)
|
@ -4,7 +4,9 @@ from typing import Any, Dict, List
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.chains.base import Chain, Memory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeMemory(Memory, BaseModel):
|
||||
@ -133,3 +135,31 @@ def test_run_arg_with_memory() -> None:
|
||||
"""Test run method works when arg is passed."""
|
||||
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
|
||||
chain.run("bar")
|
||||
|
||||
|
||||
def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
callback_manager=CallbackManager(handlers=[handler]), verbose=True
|
||||
)
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_not_verbose() -> None:
|
||||
"""Test run method works when callback manager is passed and not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
30
tests/unit_tests/llms/test_callbacks.py
Normal file
30
tests/unit_tests/llms/test_callbacks.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Test LLM callbacks."""
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def test_llm_with_callbacks() -> None:
|
||||
"""Test LLM callbacks."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_llm_with_callbacks_not_verbose() -> None:
|
||||
"""Test LLM callbacks but not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
@ -7,8 +7,8 @@ from pydantic import BaseModel
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
||||
from langchain.llms.base import BaseLLM, LLMResult
|
||||
from langchain.schema import Generation
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
|
Loading…
Reference in New Issue
Block a user