switch up defaults (#485)

I kinda like this just because we call `self.callback_manager` so many
times, and that's nicer than `self._get_callback_manager()`.
harrison/callback-updates
Harrison Chase committed 1 year ago via GitHub
parent 52490e2dcd
commit e3edd74eab
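
Every file touched below gets the same treatment: the `Optional` field plus `_get_callback_manager()` getter is replaced by a `callback_manager` field with a `default_factory` and a validator that maps `None` to the global manager. A rough before/after sketch of the call-site difference, using standalone stand-ins (not the real langchain classes; the `run` methods are purely illustrative):

```python
from typing import Optional

from pydantic import BaseModel, Field, validator  # pydantic v1 API, as in this PR


class CallbackManager:
    """Stand-in for langchain's BaseCallbackManager (hypothetical)."""

    def on_chain_start(self, serialized: dict, inputs: dict) -> None:
        print("chain start:", serialized, inputs)


def get_callback_manager() -> CallbackManager:
    """Stand-in for langchain.callbacks.get_callback_manager."""
    return CallbackManager()


class Before(BaseModel):
    """Old shape: the field may be None, so every call site goes through a getter."""

    callback_manager: Optional[CallbackManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _get_callback_manager(self) -> CallbackManager:
        return self.callback_manager or get_callback_manager()

    def run(self) -> None:
        self._get_callback_manager().on_chain_start({"name": "Before"}, {})


class After(BaseModel):
    """New shape: the field is always populated, so call sites read it directly."""

    callback_manager: CallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(cls, v: Optional[CallbackManager]) -> CallbackManager:
        return v or get_callback_manager()

    def run(self) -> None:
        self.callback_manager.on_chain_start({"name": "After"}, {})
```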

@@ -220,7 +220,7 @@ class AgentExecutor(Chain, BaseModel):
             # If the tool chosen is the finishing tool, then we end and return.
             if isinstance(output, AgentFinish):
                 if self.verbose:
-                    self._get_callback_manager().on_agent_end(output.log, color="green")
+                    self.callback_manager.on_agent_end(output.log, color="green")
                 final_output = output.return_values
                 if self.return_intermediate_steps:
                     final_output["intermediate_steps"] = intermediate_steps
@@ -230,7 +230,7 @@ class AgentExecutor(Chain, BaseModel):
             if output.tool in name_to_tool_map:
                 chain = name_to_tool_map[output.tool]
                 if self.verbose:
-                    self._get_callback_manager().on_tool_start(
+                    self.callback_manager.on_tool_start(
                         {"name": str(chain)[:60] + "..."}, output, color="green"
                     )
                 # We then call the tool on the tool input to get an observation
@@ -238,13 +238,13 @@ class AgentExecutor(Chain, BaseModel):
                 color = color_mapping[output.tool]
             else:
                 if self.verbose:
-                    self._get_callback_manager().on_tool_start(
+                    self.callback_manager.on_tool_start(
                         {"name": "N/A"}, output, color="green"
                     )
                 observation = f"{output.tool} is not a valid tool, try another one."
                 color = None
             if self.verbose:
-                self._get_callback_manager().on_tool_end(
+                self.callback_manager.on_tool_end(
                     observation,
                     color=color,
                     observation_prefix=self.agent.observation_prefix,

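For context on what `AgentExecutor._call` is invoking above, here is a minimal stand-in manager that just logs the three hooks exercised in these hunks. Only the call sites visible in the diff are mirrored; the real `BaseCallbackManager` interface is wider, and these signatures (and the class name) are assumptions for illustration:

```python
from typing import Any, Dict, Optional


class LoggingCallbackManager:
    """Hypothetical stand-in mirroring the agent-side call sites above."""

    def on_agent_end(self, log: str, color: Optional[str] = None, **kwargs: Any) -> None:
        print(f"[agent end] {log}")

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        action: Any,
        color: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        print(f"[tool start] {serialized['name']}: {action}")

    def on_tool_end(self, observation: str, color: Optional[str] = None, **kwargs: Any) -> None:
        # The diff also passes observation_prefix (and likely other keyword
        # arguments); **kwargs absorbs them in this sketch.
        print(f"[tool end] {observation}")
```
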
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, validator
 
 import langchain
 from langchain.callbacks import get_callback_manager
@@ -44,7 +44,7 @@ class Chain(BaseModel, ABC):
     """Base interface that all chains should implement."""
 
     memory: Optional[Memory] = None
-    callback_manager: Optional[BaseCallbackManager] = None
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
     verbose: bool = Field(
         default_factory=_get_verbosity
     )  # Whether to print the response text
@@ -54,11 +54,15 @@ class Chain(BaseModel, ABC):
         arbitrary_types_allowed = True
 
-    def _get_callback_manager(self) -> BaseCallbackManager:
-        """Get the callback manager."""
-        if self.callback_manager is not None:
-            return self.callback_manager
-        return get_callback_manager()
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as context manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
 
     @property
     @abstractmethod
@@ -120,12 +124,12 @@ class Chain(BaseModel, ABC):
             inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
         if self.verbose:
-            self._get_callback_manager().on_chain_start(
+            self.callback_manager.on_chain_start(
                 {"name": self.__class__.__name__}, inputs
             )
         outputs = self._call(inputs)
         if self.verbose:
-            self._get_callback_manager().on_chain_end(outputs)
+            self.callback_manager.on_chain_end(outputs)
         self._validate_outputs(outputs)
         if self.memory is not None:
             self.memory.save_context(inputs, outputs)

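The docstring above says this "allows users to pass in None". Both pieces are needed for that: `default_factory` only covers the case where the field is omitted, while the `pre=True, always=True` validator also rewrites an explicit `None` before type validation would reject it. A self-contained sketch of just that interaction (hypothetical names, not the langchain classes):

```python
from typing import Optional

from pydantic import BaseModel, Field, validator  # pydantic v1 API


class Manager:
    """Hypothetical stand-in for BaseCallbackManager."""


_DEFAULT = Manager()


def get_manager() -> Manager:
    return _DEFAULT


class Example(BaseModel):
    manager: Manager = Field(default_factory=get_manager)

    class Config:
        arbitrary_types_allowed = True

    @validator("manager", pre=True, always=True)
    def _default_manager(cls, v: Optional[Manager]) -> Manager:
        # pre=True runs before type validation (so an explicit None is let in),
        # always=True makes the validator run even when the default is used.
        return v or get_manager()


assert Example().manager is _DEFAULT                        # omitted: default_factory
assert Example(manager=None).manager is _DEFAULT            # explicit None: validator
assert Example(manager=Manager()).manager is not _DEFAULT   # explicit manager kept
```
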
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, validator
 
 import langchain
 from langchain.callbacks import get_callback_manager
@@ -23,7 +23,7 @@ class BaseLLM(BaseModel, ABC):
     cache: Optional[bool] = None
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
-    callback_manager: Optional[BaseCallbackManager] = None
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
 
     class Config:
         """Configuration for this pydantic object."""
@@ -31,18 +31,22 @@ class BaseLLM(BaseModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as context manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
+
     @abstractmethod
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
 
-    def _get_callback_manager(self) -> BaseCallbackManager:
-        """Get the callback manager."""
-        if self.callback_manager is not None:
-            return self.callback_manager
-        return get_callback_manager()
-
     def generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
@@ -55,12 +59,12 @@ class BaseLLM(BaseModel, ABC):
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
             if self.verbose:
-                self._get_callback_manager().on_llm_start(
+                self.callback_manager.on_llm_start(
                     {"name": self.__class__.__name__}, prompts
                 )
             output = self._generate(prompts, stop=stop)
             if self.verbose:
-                self._get_callback_manager().on_llm_end(output)
+                self.callback_manager.on_llm_end(output)
             return output
         params = self._llm_dict()
         params["stop"] = stop
@@ -75,11 +79,11 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
-        self._get_callback_manager().on_llm_start(
+        self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__}, missing_prompts
         )
         new_results = self._generate(missing_prompts, stop=stop)
-        self._get_callback_manager().on_llm_end(new_results)
+        self.callback_manager.on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
             existing_prompts[i] = result
             prompt = prompts[i]

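One detail worth calling out in the LLM path: in the cached branch above (the block that generates `missing_prompts`), `on_llm_start` and `on_llm_end` are called with no `verbose` guard and no `None` check. That only works because `callback_manager` is now a non-Optional field filled in by `default_factory` when the instance is created, rather than evaluated once at import time as a plain class-level default would be. A compact stand-in showing that the unguarded call sites are safe; `FakeLLM` and its `generate` body are hypothetical and mirror only the pattern, not `BaseLLM` itself:

```python
from typing import List

from pydantic import BaseModel, Field  # pydantic v1 API


class CallbackManager:
    """Hypothetical stand-in for BaseCallbackManager."""

    def on_llm_start(self, serialized: dict, prompts: List[str]) -> None:
        print("llm start:", serialized["name"], f"{len(prompts)} prompt(s)")

    def on_llm_end(self, output: object) -> None:
        print("llm end:", output)


def get_callback_manager() -> CallbackManager:
    return CallbackManager()


class FakeLLM(BaseModel):
    # Non-Optional and constructed per instance via default_factory, so the
    # unguarded calls below can never hit None.
    callback_manager: CallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        arbitrary_types_allowed = True

    def generate(self, prompts: List[str]) -> List[str]:
        # Mirrors the unguarded calls in the cached path of BaseLLM.generate above.
        self.callback_manager.on_llm_start({"name": type(self).__name__}, prompts)
        results = [p.upper() for p in prompts]  # placeholder "generation"
        self.callback_manager.on_llm_end(results)
        return results


print(FakeLLM().generate(["hello world"]))
```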