pass callbacks along baby ai (#7908)

This commit is contained in:
Harrison Chase 2023-07-19 22:40:33 -07:00 committed by GitHub
parent a4c5914c9a
commit df84e1bb64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -60,7 +60,7 @@ class BabyAGI(Chain, BaseModel):
return []
def get_next_task(
self, result: str, task_description: str, objective: str
self, result: str, task_description: str, objective: str, **kwargs: Any
) -> List[Dict]:
"""Get the next task."""
task_names = [t["task_name"] for t in self.task_list]
@ -71,13 +71,16 @@ class BabyAGI(Chain, BaseModel):
task_description=task_description,
incomplete_tasks=incomplete_tasks,
objective=objective,
**kwargs,
)
new_tasks = response.split("\n")
return [
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
]
def prioritize_tasks(self, this_task_id: int, objective: str) -> List[Dict]:
def prioritize_tasks(
self, this_task_id: int, objective: str, **kwargs: Any
) -> List[Dict]:
"""Prioritize tasks."""
task_names = [t["task_name"] for t in list(self.task_list)]
next_task_id = int(this_task_id) + 1
@ -85,6 +88,7 @@ class BabyAGI(Chain, BaseModel):
task_names=", ".join(task_names),
next_task_id=str(next_task_id),
objective=objective,
**kwargs,
)
new_tasks = response.split("\n")
prioritized_task_list = []
@ -107,11 +111,11 @@ class BabyAGI(Chain, BaseModel):
return []
return [str(item.metadata["task"]) for item in results]
def execute_task(self, objective: str, task: str, k: int = 5) -> str:
def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
"""Execute a task."""
context = self._get_top_tasks(query=objective, k=k)
return self.execution_chain.run(
objective=objective, context="\n".join(context), task=task
objective=objective, context="\n".join(context), task=task, **kwargs
)
def _call(
@ -120,6 +124,7 @@ class BabyAGI(Chain, BaseModel):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the agent."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
objective = inputs["objective"]
first_task = inputs.get("first_task", "Make a todo list")
self.add_task({"task_id": 1, "task_name": first_task})
@ -133,7 +138,9 @@ class BabyAGI(Chain, BaseModel):
self.print_next_task(task)
# Step 2: Execute the task
result = self.execute_task(objective, task["task_name"])
result = self.execute_task(
objective, task["task_name"], callbacks=_run_manager.get_child()
)
this_task_id = int(task["task_id"])
self.print_task_result(result)
@ -146,12 +153,21 @@ class BabyAGI(Chain, BaseModel):
)
# Step 4: Create new tasks and reprioritize task list
new_tasks = self.get_next_task(result, task["task_name"], objective)
new_tasks = self.get_next_task(
result,
task["task_name"],
objective,
callbacks=_run_manager.get_child(),
)
for new_task in new_tasks:
self.task_id_counter += 1
new_task.update({"task_id": self.task_id_counter})
self.add_task(new_task)
self.task_list = deque(self.prioritize_tasks(this_task_id, objective))
self.task_list = deque(
self.prioritize_tasks(
this_task_id, objective, callbacks=_run_manager.get_child()
)
)
num_iters += 1
if self.max_iterations is not None and num_iters == self.max_iterations:
print(