Bagatur 2023-08-09 14:44:29 -07:00
commit 05cdd22c39
15 changed files with 280 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ Here are the agents available in LangChain.
### [Zero-shot ReAct](/docs/modules/agents/agent_types/react.html)
This agent uses the [ReAct](https://arxiv.org/pdf/2205.00445.pdf) framework to determine which tool to use
This agent uses the [ReAct](https://arxiv.org/pdf/2210.03629) framework to determine which tool to use
based solely on the tool's description. Any number of tools can be provided.
This agent requires that a description is provided for each tool.
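
For illustration (not part of this commit): a minimal sketch of constructing this agent type, assuming the `initialize_agent`/`AgentType` API available around this release.

```python
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
tools = load_tools(["llm-math"], llm=llm)

# The zero-shot ReAct agent chooses a tool based solely on each tool's description.
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)
agent.run("What is 7 raised to the 0.5 power?")
```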

View File

@@ -62,9 +62,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
_config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return self(input, **_config, **kwargs)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return self(input, **config_kwargs, **kwargs)
async def ainvoke(
self,
@@ -77,10 +79,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
_config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return await self.acall(input, **_config, **kwargs)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return await self.acall(input, **config_kwargs, **kwargs)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

View File

@@ -103,14 +103,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **_config, **kwargs
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
).generations[0][0],
).message,
)
@@ -129,10 +131,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **_config, **kwargs
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@@ -1,10 +1,9 @@
"""Loads local airbyte json files."""
from typing import Any, Callable, Iterator, List, Mapping, Optional
from libs.langchain.langchain.utils.utils import guard_import
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.utils.utils import guard_import
RecordHandler = Callable[[Any, Optional[str]], Document]

View File

@@ -219,10 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
result = self.generate_prompt(
[self._convert_input(input)], stop=stop, **_config, **kwargs
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
)
return result.generations[0][0].text
@@ -240,10 +242,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **_config, **kwargs
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
)
return llm_result.generations[0][0].text

View File

@@ -0,0 +1,46 @@
from operator import itemgetter
from typing import Any, Callable, List, Mapping, Optional, Union
from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
class OpenAIFunction(TypedDict):
"""A function description for ChatOpenAI"""
name: str
"""The name of the function."""
description: str
"""The description of the function."""
parameters: dict
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
"""A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]]
def __init__(
self,
runnables: Mapping[
str,
Union[
Runnable[dict, Any],
Callable[[dict], Any],
],
],
functions: Optional[List[OpenAIFunction]] = None,
):
if functions is not None:
assert len(functions) == len(runnables)
assert all(func["name"] in runnables for func in functions)
router = (
JsonOutputFunctionsParser(args_only=False)
| {"key": itemgetter("name"), "input": itemgetter("arguments")}
| RouterRunnable(runnables)
)
super().__init__(bound=router, kwargs={}, functions=functions)
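
For illustration (not part of this commit): a usage sketch for the new router, mirroring the unit test added at the end of this commit; the model name is illustrative.

```python
from langchain.chat_models import ChatOpenAI
from langchain.runnables.openai_functions import OpenAIFunctionsRouter

router = OpenAIFunctionsRouter(
    {
        "accept": lambda kw: f"Accepted draft: {kw['draft']}!",
        "revise": lambda kw: f"Revised draft: no more {kw['notes']}!",
    },
    functions=[
        {
            "name": "accept",
            "description": "Accepts the draft.",
            "parameters": {
                "type": "object",
                "properties": {"draft": {"type": "string"}},
            },
        },
        {
            "name": "revise",
            "description": "Sends the draft for revision.",
            "parameters": {
                "type": "object",
                "properties": {"notes": {"type": "string"}},
            },
        },
    ],
)

# Bind the function schemas to the model, then route on the function call it returns.
chain = ChatOpenAI(model="gpt-3.5-turbo-0613").bind(functions=router.functions) | router
chain.invoke("Something about turtles?")
```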

View File

@@ -107,9 +107,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
return self.get_relevant_documents(input, **_config)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return self.get_relevant_documents(input, **config_kwargs)
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None
@@ -118,9 +120,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
return await self.aget_relevant_documents(input, **_config)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return await self.aget_relevant_documents(input, **config_kwargs)
@abstractmethod
def _get_relevant_documents(

View File

@@ -1229,7 +1229,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that binds a runnable to a set of kwargs.
A runnable that delegates calls to another runnable with a set of kwargs.
"""
bound: Runnable[Input, Output]
@@ -1314,8 +1314,15 @@ class RouterRunnable(
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
super().__init__(runnables=runnables)
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
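
For illustration (not part of this commit): the widened constructor coerces plain callables to runnables, so a mapping of functions can now be passed directly. A small sketch of `RouterRunnable`'s `{"key": ..., "input": ...}` routing contract.

```python
from langchain.schema.runnable import RouterRunnable

# Callables are coerced to runnables inside the constructor, so no explicit
# RunnableLambda wrapping is needed.
router = RouterRunnable(
    {
        "add_one": lambda x: x + 1,
        "double": lambda x: x * 2,
    }
)

# RouterRunnable expects a dict carrying the routing key and the input for that branch.
router.invoke({"key": "double", "input": 21})  # -> 42
```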

View File

@@ -502,6 +502,18 @@ def _construct_run_evaluator(
return run_evaluator
def _get_keys(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_outputs: Optional[List[str]],
example_outputs: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
return input_key, prediction_key, reference_key
def _load_run_evaluators(
config: RunEvalConfig,
run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
"""
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
run_evaluators = []
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
input_key, prediction_key, reference_key = None, None, None
if config.evaluators or any(
[isinstance(e, EvaluatorType) for e in config.evaluators]
):
input_key, prediction_key, reference_key = _get_keys(
config, run_inputs, run_outputs, example_outputs
)
for eval_config in config.evaluators:
run_evaluator = _construct_run_evaluator(
eval_config,
@@ -1078,15 +1094,15 @@ def _run_on_examples(
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
tracer = LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
evalution_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
@@ -1095,7 +1111,7 @@
for i, example in enumerate(examples):
result = _run_llm_or_chain(
example,
llm_or_chain_factory,
wrapped_model,
num_repetitions,
tags=tags,
callbacks=callbacks,
@@ -1118,8 +1134,8 @@ def _prepare_eval_run(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, wrapped_model)
try:
project = client.create_project(project_name)
except ValueError as e:
@@ -1134,7 +1150,7 @@ )
)
dataset = client.read_dataset(dataset_name=dataset_name)
examples = client.list_examples(dataset_id=str(dataset.id))
return llm_or_chain_factory, project_name, dataset, examples
return wrapped_model, project_name, dataset, examples
async def arun_on_dataset(
@@ -1260,13 +1276,13 @@ async def arun_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
results = await _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
@@ -1427,14 +1443,14 @@ def run_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
if concurrency_level in (0, 1):
results = _run_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
@@ -1448,7 +1464,7 @@ def run_on_dataset(
coro = _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
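
For illustration (not part of this commit): these helpers back the public `run_on_dataset`/`arun_on_dataset` entry points. A hedged sketch, assuming the `langchain.smith` API at this version; the dataset name and evaluator choice are illustrative.

```python
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, run_on_dataset

client = Client()
evaluation_config = RunEvalConfig(evaluators=["qa"])

# The model/chain argument is wrapped via _wrap_in_chain_factory internally,
# so a model object or a chain factory both work here.
run_on_dataset(
    client,
    "my-dataset",  # illustrative dataset name
    ChatOpenAI(temperature=0),
    evaluation=evaluation_config,
)
```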

View File

@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
**kwargs: Any,
) -> Any:
config = config or {}
return self.run(input, **config, **kwargs)
return self.run(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
return super().ainvoke(input, config, **kwargs)
config = config or {}
return await self.arun(input, **config, **kwargs)
return await self.arun(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
# --- Tool ---

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.259"
version = "0.0.260"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@@ -0,0 +1,9 @@
"""Test the airbyte document loader.
Light test to ensure that the airbyte document loader can be imported.
"""
def test_airbyte_import() -> None:
"""Test that the airbyte document loader can be imported."""
from langchain.document_loaders import airbyte # noqa

View File

@@ -0,0 +1,31 @@
# serializer version: 1
# name: test_openai_functions_router
list([
dict({
'description': 'Sends the draft for revision.',
'name': 'revise',
'parameters': dict({
'properties': dict({
'notes': dict({
'description': "The editor's notes to guide the revision.",
'type': 'string',
}),
}),
'type': 'object',
}),
}),
dict({
'description': 'Accepts the draft.',
'name': 'accept',
'parameters': dict({
'properties': dict({
'draft': dict({
'description': 'The draft to accept.',
'type': 'string',
}),
}),
'type': 'object',
}),
}),
])
# ---

View File

@@ -0,0 +1,95 @@
from typing import Any, List, Optional
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration
class FakeChatOpenAI(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "accept",
"arguments": '{\n "draft": "turtles"\n}',
}
},
)
)
]
)
def test_openai_functions_router(
snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
revise = mocker.Mock(
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
)
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
router = OpenAIFunctionsRouter(
{
"revise": revise,
"accept": accept,
},
functions=[
{
"name": "revise",
"description": "Sends the draft for revision.",
"parameters": {
"type": "object",
"properties": {
"notes": {
"type": "string",
"description": "The editor's notes to guide the revision.",
},
},
},
},
{
"name": "accept",
"description": "Accepts the draft.",
"parameters": {
"type": "object",
"properties": {
"draft": {
"type": "string",
"description": "The draft to accept.",
},
},
},
},
],
)
model = FakeChatOpenAI()
chain = model.bind(functions=router.functions) | router
assert router.functions == snapshot
assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
revise.assert_not_called()
accept.assert_called_once_with({"draft": "turtles"})