mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
pass callbacks along baby ai (#7908)
This commit is contained in:
parent
a4c5914c9a
commit
df84e1bb64
@ -60,7 +60,7 @@ class BabyAGI(Chain, BaseModel):
|
||||
return []
|
||||
|
||||
def get_next_task(
|
||||
self, result: str, task_description: str, objective: str
|
||||
self, result: str, task_description: str, objective: str, **kwargs: Any
|
||||
) -> List[Dict]:
|
||||
"""Get the next task."""
|
||||
task_names = [t["task_name"] for t in self.task_list]
|
||||
@ -71,13 +71,16 @@ class BabyAGI(Chain, BaseModel):
|
||||
task_description=task_description,
|
||||
incomplete_tasks=incomplete_tasks,
|
||||
objective=objective,
|
||||
**kwargs,
|
||||
)
|
||||
new_tasks = response.split("\n")
|
||||
return [
|
||||
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
|
||||
]
|
||||
|
||||
def prioritize_tasks(self, this_task_id: int, objective: str) -> List[Dict]:
|
||||
def prioritize_tasks(
|
||||
self, this_task_id: int, objective: str, **kwargs: Any
|
||||
) -> List[Dict]:
|
||||
"""Prioritize tasks."""
|
||||
task_names = [t["task_name"] for t in list(self.task_list)]
|
||||
next_task_id = int(this_task_id) + 1
|
||||
@ -85,6 +88,7 @@ class BabyAGI(Chain, BaseModel):
|
||||
task_names=", ".join(task_names),
|
||||
next_task_id=str(next_task_id),
|
||||
objective=objective,
|
||||
**kwargs,
|
||||
)
|
||||
new_tasks = response.split("\n")
|
||||
prioritized_task_list = []
|
||||
@ -107,11 +111,11 @@ class BabyAGI(Chain, BaseModel):
|
||||
return []
|
||||
return [str(item.metadata["task"]) for item in results]
|
||||
|
||||
def execute_task(self, objective: str, task: str, k: int = 5) -> str:
|
||||
def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
|
||||
"""Execute a task."""
|
||||
context = self._get_top_tasks(query=objective, k=k)
|
||||
return self.execution_chain.run(
|
||||
objective=objective, context="\n".join(context), task=task
|
||||
objective=objective, context="\n".join(context), task=task, **kwargs
|
||||
)
|
||||
|
||||
def _call(
|
||||
@ -120,6 +124,7 @@ class BabyAGI(Chain, BaseModel):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the agent."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
objective = inputs["objective"]
|
||||
first_task = inputs.get("first_task", "Make a todo list")
|
||||
self.add_task({"task_id": 1, "task_name": first_task})
|
||||
@ -133,7 +138,9 @@ class BabyAGI(Chain, BaseModel):
|
||||
self.print_next_task(task)
|
||||
|
||||
# Step 2: Execute the task
|
||||
result = self.execute_task(objective, task["task_name"])
|
||||
result = self.execute_task(
|
||||
objective, task["task_name"], callbacks=_run_manager.get_child()
|
||||
)
|
||||
this_task_id = int(task["task_id"])
|
||||
self.print_task_result(result)
|
||||
|
||||
@ -146,12 +153,21 @@ class BabyAGI(Chain, BaseModel):
|
||||
)
|
||||
|
||||
# Step 4: Create new tasks and reprioritize task list
|
||||
new_tasks = self.get_next_task(result, task["task_name"], objective)
|
||||
new_tasks = self.get_next_task(
|
||||
result,
|
||||
task["task_name"],
|
||||
objective,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
for new_task in new_tasks:
|
||||
self.task_id_counter += 1
|
||||
new_task.update({"task_id": self.task_id_counter})
|
||||
self.add_task(new_task)
|
||||
self.task_list = deque(self.prioritize_tasks(this_task_id, objective))
|
||||
self.task_list = deque(
|
||||
self.prioritize_tasks(
|
||||
this_task_id, objective, callbacks=_run_manager.get_child()
|
||||
)
|
||||
)
|
||||
num_iters += 1
|
||||
if self.max_iterations is not None and num_iters == self.max_iterations:
|
||||
print(
|
||||
|
Loading…
Reference in New Issue
Block a user