From 1ef3ab4d0e663be147a5bcf542045e1f4a065778 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 24 Oct 2022 19:56:26 -0700 Subject: [PATCH] Harrison/add natbot (#18) --- examples/natbot.py | 88 ++++++ langchain/chains/natbot/__init__.py | 4 + langchain/chains/natbot/base.py | 65 ++++ langchain/chains/natbot/crawler.py | 408 +++++++++++++++++++++++++ langchain/chains/natbot/prompt.py | 144 +++++++++ tests/unit_tests/chains/test_natbot.py | 39 +++ tests/unit_tests/llms/fake_llm.py | 2 +- 7 files changed, 749 insertions(+), 1 deletion(-) create mode 100644 examples/natbot.py create mode 100644 langchain/chains/natbot/__init__.py create mode 100644 langchain/chains/natbot/base.py create mode 100644 langchain/chains/natbot/crawler.py create mode 100644 langchain/chains/natbot/prompt.py create mode 100644 tests/unit_tests/chains/test_natbot.py diff --git a/examples/natbot.py b/examples/natbot.py new file mode 100644 index 00000000..9f380bf7 --- /dev/null +++ b/examples/natbot.py @@ -0,0 +1,88 @@ +"""Run NatBot.""" +import time + +from langchain.chains.natbot.base import NatBotChain +from langchain.chains.natbot.crawler import Crawler # type: ignore + + +def run_cmd(cmd: str, _crawler: Crawler) -> None: + """Run command.""" + cmd = cmd.split("\n")[0] + + if cmd.startswith("SCROLL UP"): + _crawler.scroll("up") + elif cmd.startswith("SCROLL DOWN"): + _crawler.scroll("down") + elif cmd.startswith("CLICK"): + commasplit = cmd.split(",") + id = commasplit[0].split(" ")[1] + _crawler.click(id) + elif cmd.startswith("TYPE"): + spacesplit = cmd.split(" ") + id = spacesplit[1] + text_pieces = spacesplit[2:] + text = " ".join(text_pieces) + # Strip leading and trailing double quotes + text = text[1:-1] + + if cmd.startswith("TYPESUBMIT"): + text += "\n" + _crawler.type(id, text) + + time.sleep(2) + + +if __name__ == "__main__": + + objective = "Make a reservation for 2 at 7pm at bistro vida in menlo park" + print("\nWelcome to natbot! What is your objective?") + i = input() + if len(i) > 0: + objective = i + quiet = False + nat_bot_chain = NatBotChain.from_default(objective) + _crawler = Crawler() + _crawler.go_to_page("google.com") + try: + while True: + browser_content = "\n".join(_crawler.crawl()) + llm_command = nat_bot_chain.run(_crawler.page.url, browser_content) + if not quiet: + print("URL: " + _crawler.page.url) + print("Objective: " + objective) + print("----------------\n" + browser_content + "\n----------------\n") + if len(llm_command) > 0: + print("Suggested command: " + llm_command) + + command = input() + if command == "r" or command == "": + run_cmd(llm_command, _crawler) + elif command == "g": + url = input("URL:") + _crawler.go_to_page(url) + elif command == "u": + _crawler.scroll("up") + time.sleep(1) + elif command == "d": + _crawler.scroll("down") + time.sleep(1) + elif command == "c": + id = input("id:") + _crawler.click(id) + time.sleep(1) + elif command == "t": + id = input("id:") + text = input("text:") + _crawler.type(id, text) + time.sleep(1) + elif command == "o": + objective = input("Objective:") + else: + print( + "(g) to visit url\n(u) scroll up\n(d) scroll down\n(c) to click" + "\n(t) to type\n(h) to view commands again" + "\n(r/enter) to run suggested command\n(o) change objective" + ) + except KeyboardInterrupt: + print("\n[!] Ctrl+C detected, exiting gracefully.") + exit(0) diff --git a/langchain/chains/natbot/__init__.py b/langchain/chains/natbot/__init__.py new file mode 100644 index 00000000..45a2231a --- /dev/null +++ b/langchain/chains/natbot/__init__.py @@ -0,0 +1,4 @@ +"""Implement a GPT-3 driven browser. + +Heavily influenced from https://github.com/nat/natbot +""" diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py new file mode 100644 index 00000000..5c2435f3 --- /dev/null +++ b/langchain/chains/natbot/base.py @@ -0,0 +1,65 @@ +"""Implement a GPT-3 driven browser.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.natbot.prompt import PROMPT +from langchain.llms.base import LLM +from langchain.llms.openai import OpenAI + + +class NatBotChain(Chain, BaseModel): + """Implement a GPT-3 driven browser.""" + + llm: LLM + objective: str + input_url_key: str = "url" + input_browser_content_key: str = "browser_content" + previous_command: str = "" + output_key: str = "command" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @classmethod + def from_default(cls, objective: str) -> "NatBotChain": + """Load with default LLM.""" + llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) + return cls(llm=llm, objective=objective) + + @property + def input_keys(self) -> List[str]: + """Expect url and browser content.""" + return [self.input_url_key, self.input_browser_content_key] + + @property + def output_keys(self) -> List[str]: + """Return command.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) + url = inputs[self.input_url_key] + browser_content = inputs[self.input_browser_content_key] + llm_cmd = llm_executor.predict( + objective=self.objective, + url=url[:100], + previous_command=self.previous_command, + browser_content=browser_content[:4500], + ) + llm_cmd = llm_cmd.strip() + self.previous_command = llm_cmd + return {self.output_key: llm_cmd} + + def run(self, url: str, browser_content: str) -> str: + """More user-friendly interface for interfacing with natbot.""" + _inputs = { + self.input_url_key: url, + self.input_browser_content_key: browser_content, + } + return self(_inputs)[self.output_key] diff --git a/langchain/chains/natbot/crawler.py b/langchain/chains/natbot/crawler.py new file mode 100644 index 00000000..341b890b --- /dev/null +++ b/langchain/chains/natbot/crawler.py @@ -0,0 +1,408 @@ +# flake8: noqa +# type: ignore +import time +from sys import platform + +black_listed_elements = { + "html", + "head", + "title", + "meta", + "iframe", + "body", + "script", + "style", + "path", + "svg", + "br", + "::marker", +} + + +class Crawler: + def __init__(self): + try: + from playwright.sync_api import sync_playwright + except ImportError: + raise ValueError( + "Could not import playwright python package. " + "Please it install it with `pip install playwright`." + ) + self.browser = ( + sync_playwright() + .start() + .chromium.launch( + headless=False, + ) + ) + + self.page = self.browser.new_page() + self.page.set_viewport_size({"width": 1280, "height": 1080}) + + def go_to_page(self, url): + self.page.goto(url=url if "://" in url else "http://" + url) + self.client = self.page.context.new_cdp_session(self.page) + self.page_element_buffer = {} + + def scroll(self, direction): + if direction == "up": + self.page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" + ) + elif direction == "down": + self.page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" + ) + + def click(self, id): + # Inject javascript into the page which removes the target= attribute from all links + js = """ + links = document.getElementsByTagName("a"); + for (var i = 0; i < links.length; i++) { + links[i].removeAttribute("target"); + } + """ + self.page.evaluate(js) + + element = self.page_element_buffer.get(int(id)) + if element: + x = element.get("center_x") + y = element.get("center_y") + + self.page.mouse.click(x, y) + else: + print("Could not find element") + + def type(self, id, text): + self.click(id) + self.page.keyboard.type(text) + + def enter(self): + self.page.keyboard.press("Enter") + + def crawl(self): + page = self.page + page_element_buffer = self.page_element_buffer + start = time.time() + + page_state_as_text = [] + + device_pixel_ratio = page.evaluate("window.devicePixelRatio") + if platform == "darwin" and device_pixel_ratio == 1: # lies + device_pixel_ratio = 2 + + win_scroll_x = page.evaluate("window.scrollX") + win_scroll_y = page.evaluate("window.scrollY") + win_upper_bound = page.evaluate("window.pageYOffset") + win_left_bound = page.evaluate("window.pageXOffset") + win_width = page.evaluate("window.screen.width") + win_height = page.evaluate("window.screen.height") + win_right_bound = win_left_bound + win_width + win_lower_bound = win_upper_bound + win_height + document_offset_height = page.evaluate("document.body.offsetHeight") + document_scroll_height = page.evaluate("document.body.scrollHeight") + + # percentage_progress_start = (win_upper_bound / document_scroll_height) * 100 + # percentage_progress_end = ( + # (win_height + win_upper_bound) / document_scroll_height + # ) * 100 + percentage_progress_start = 1 + percentage_progress_end = 2 + + page_state_as_text.append( + { + "x": 0, + "y": 0, + "text": "[scrollbar {:0.2f}-{:0.2f}%]".format( + round(percentage_progress_start, 2), round(percentage_progress_end) + ), + } + ) + + tree = self.client.send( + "DOMSnapshot.captureSnapshot", + {"computedStyles": [], "includeDOMRects": True, "includePaintOrder": True}, + ) + strings = tree["strings"] + document = tree["documents"][0] + nodes = document["nodes"] + backend_node_id = nodes["backendNodeId"] + attributes = nodes["attributes"] + node_value = nodes["nodeValue"] + parent = nodes["parentIndex"] + node_types = nodes["nodeType"] + node_names = nodes["nodeName"] + is_clickable = set(nodes["isClickable"]["index"]) + + text_value = nodes["textValue"] + text_value_index = text_value["index"] + text_value_values = text_value["value"] + + input_value = nodes["inputValue"] + input_value_index = input_value["index"] + input_value_values = input_value["value"] + + input_checked = nodes["inputChecked"] + layout = document["layout"] + layout_node_index = layout["nodeIndex"] + bounds = layout["bounds"] + + cursor = 0 + html_elements_text = [] + + child_nodes = {} + elements_in_view_port = [] + + anchor_ancestry = {"-1": (False, None)} + button_ancestry = {"-1": (False, None)} + + def convert_name(node_name, has_click_handler): + if node_name == "a": + return "link" + if node_name == "input": + return "input" + if node_name == "img": + return "img" + if ( + node_name == "button" or has_click_handler + ): # found pages that needed this quirk + return "button" + else: + return "text" + + def find_attributes(attributes, keys): + values = {} + + for [key_index, value_index] in zip(*(iter(attributes),) * 2): + if value_index < 0: + continue + key = strings[key_index] + value = strings[value_index] + + if key in keys: + values[key] = value + keys.remove(key) + + if not keys: + return values + + return values + + def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id): + parent_id_str = str(parent_id) + if not parent_id_str in hash_tree: + parent_name = strings[node_names[parent_id]].lower() + grand_parent_id = parent[parent_id] + + add_to_hash_tree( + hash_tree, tag, parent_id, parent_name, grand_parent_id + ) + + is_parent_desc_anchor, anchor_id = hash_tree[parent_id_str] + + # even if the anchor is nested in another anchor, we set the "root" for all descendants to be ::Self + if node_name == tag: + value = (True, node_id) + elif ( + is_parent_desc_anchor + ): # reuse the parent's anchor_id (which could be much higher in the tree) + value = (True, anchor_id) + else: + value = ( + False, + None, + ) # not a descendant of an anchor, most likely it will become text, an interactive element or discarded + + hash_tree[str(node_id)] = value + + return value + + for index, node_name_index in enumerate(node_names): + node_parent = parent[index] + node_name = strings[node_name_index].lower() + + is_ancestor_of_anchor, anchor_id = add_to_hash_tree( + anchor_ancestry, "a", index, node_name, node_parent + ) + + is_ancestor_of_button, button_id = add_to_hash_tree( + button_ancestry, "button", index, node_name, node_parent + ) + + try: + cursor = layout_node_index.index( + index + ) # todo replace this with proper cursoring, ignoring the fact this is O(n^2) for the moment + except: + continue + + if node_name in black_listed_elements: + continue + + [x, y, width, height] = bounds[cursor] + x /= device_pixel_ratio + y /= device_pixel_ratio + width /= device_pixel_ratio + height /= device_pixel_ratio + + elem_left_bound = x + elem_top_bound = y + elem_right_bound = x + width + elem_lower_bound = y + height + + partially_is_in_viewport = ( + elem_left_bound < win_right_bound + and elem_right_bound >= win_left_bound + and elem_top_bound < win_lower_bound + and elem_lower_bound >= win_upper_bound + ) + + if not partially_is_in_viewport: + continue + + meta_data = [] + + # inefficient to grab the same set of keys for kinds of objects but its fine for now + element_attributes = find_attributes( + attributes[index], ["type", "placeholder", "aria-label", "title", "alt"] + ) + + ancestor_exception = is_ancestor_of_anchor or is_ancestor_of_button + ancestor_node_key = ( + None + if not ancestor_exception + else str(anchor_id) + if is_ancestor_of_anchor + else str(button_id) + ) + ancestor_node = ( + None + if not ancestor_exception + else child_nodes.setdefault(str(ancestor_node_key), []) + ) + + if node_name == "#text" and ancestor_exception: + text = strings[node_value[index]] + if text == "|" or text == "•": + continue + ancestor_node.append({"type": "type", "value": text}) + else: + if ( + node_name == "input" and element_attributes.get("type") == "submit" + ) or node_name == "button": + node_name = "button" + element_attributes.pop( + "type", None + ) # prevent [button ... (button)..] + + for key in element_attributes: + if ancestor_exception: + ancestor_node.append( + { + "type": "attribute", + "key": key, + "value": element_attributes[key], + } + ) + else: + meta_data.append(element_attributes[key]) + + element_node_value = None + + if node_value[index] >= 0: + element_node_value = strings[node_value[index]] + if ( + element_node_value == "|" + ): # commonly used as a seperator, does not add much context - lets save ourselves some token space + continue + elif ( + node_name == "input" + and index in input_value_index + and element_node_value is None + ): + node_input_text_index = input_value_index.index(index) + text_index = input_value_values[node_input_text_index] + if node_input_text_index >= 0 and text_index >= 0: + element_node_value = strings[text_index] + + # remove redudant elements + if ancestor_exception and (node_name != "a" and node_name != "button"): + continue + + elements_in_view_port.append( + { + "node_index": str(index), + "backend_node_id": backend_node_id[index], + "node_name": node_name, + "node_value": element_node_value, + "node_meta": meta_data, + "is_clickable": index in is_clickable, + "origin_x": int(x), + "origin_y": int(y), + "center_x": int(x + (width / 2)), + "center_y": int(y + (height / 2)), + } + ) + + # lets filter further to remove anything that does not hold any text nor has click handlers + merge text from leaf#text nodes with the parent + elements_of_interest = [] + id_counter = 0 + + for element in elements_in_view_port: + node_index = element.get("node_index") + node_name = element.get("node_name") + node_value = element.get("node_value") + is_clickable = element.get("is_clickable") + origin_x = element.get("origin_x") + origin_y = element.get("origin_y") + center_x = element.get("center_x") + center_y = element.get("center_y") + meta_data = element.get("node_meta") + + inner_text = f"{node_value} " if node_value else "" + meta = "" + + if node_index in child_nodes: + for child in child_nodes.get(node_index): + entry_type = child.get("type") + entry_value = child.get("value") + + if entry_type == "attribute": + entry_key = child.get("key") + meta_data.append(f'{entry_key}="{entry_value}"') + else: + inner_text += f"{entry_value} " + + if meta_data: + meta_string = " ".join(meta_data) + meta = f" {meta_string}" + + if inner_text != "": + inner_text = f"{inner_text.strip()}" + + converted_node_name = convert_name(node_name, is_clickable) + + # not very elegant, more like a placeholder + if ( + (converted_node_name != "button" or meta == "") + and converted_node_name != "link" + and converted_node_name != "input" + and converted_node_name != "img" + and converted_node_name != "textarea" + ) and inner_text.strip() == "": + continue + + page_element_buffer[id_counter] = element + + if inner_text != "": + elements_of_interest.append( + f"""<{converted_node_name} id={id_counter}{meta}>{inner_text}""" + ) + else: + elements_of_interest.append( + f"""<{converted_node_name} id={id_counter}{meta}/>""" + ) + id_counter += 1 + + print("Parsing time: {:0.2f} seconds".format(time.time() - start)) + return elements_of_interest diff --git a/langchain/chains/natbot/prompt.py b/langchain/chains/natbot/prompt.py new file mode 100644 index 00000000..390a532d --- /dev/null +++ b/langchain/chains/natbot/prompt.py @@ -0,0 +1,144 @@ +# flake8: noqa +from langchain.prompt import Prompt + +_PROMPT_TEMPLATE = """ +You are an agent controlling a browser. You are given: + + (1) an objective that you are trying to achieve + (2) the URL of your current web page + (3) a simplified text description of what's visible in the browser window (more on that below) + +You can issue these commands: + SCROLL UP - scroll up one page + SCROLL DOWN - scroll down one page + CLICK X - click on a given element. You can only click on links, buttons, and inputs! + TYPE X "TEXT" - type the specified text into the input with id X + TYPESUBMIT X "TEXT" - same as TYPE above, except then it presses ENTER to submit the form + +The format of the browser content is highly simplified; all formatting elements are stripped. +Interactive elements such as links, inputs, buttons are represented like this: + + text + + text + +Images are rendered as their alt text like this: + + + +Based on your given objective, issue whatever command you believe will get you closest to achieving your goal. +You always start on Google; you should submit a search query to Google that will take you to the best page for +achieving your objective. And then interact with that page to achieve your objective. + +If you find yourself on Google and there are no search results displayed yet, you should probably issue a command +like "TYPESUBMIT 7 "search query"" to get to a more useful page. + +Then, if you find yourself on a Google search results page, you might issue the command "CLICK 24" to click +on the first link in the search results. (If your previous command was a TYPESUBMIT your next command should +probably be a CLICK.) + +Don't try to interact with elements that you can't see. + +Here are some examples: + +EXAMPLE 1: +================================================== +CURRENT BROWSER CONTENT: +------------------ +About +Store +Gmail +Images +(Google apps) +Sign in +(Google) + + + + +Advertising +Business +How Search works +Carbon neutral since 2007 +Privacy +Terms +Settings +------------------ +OBJECTIVE: Find a 2 bedroom house for sale in Anchorage AK for under $750k +CURRENT URL: https://www.google.com/ +YOUR COMMAND: +TYPESUBMIT 8 "anchorage redfin" +================================================== + +EXAMPLE 2: +================================================== +CURRENT BROWSER CONTENT: +------------------ +About +Store +Gmail +Images +(Google apps) +Sign in +(Google) + + + + +Advertising +Business +How Search works +Carbon neutral since 2007 +Privacy +Terms +Settings +------------------ +OBJECTIVE: Make a reservation for 4 at Dorsia at 8pm +CURRENT URL: https://www.google.com/ +YOUR COMMAND: +TYPESUBMIT 8 "dorsia nyc opentable" +================================================== + +EXAMPLE 3: +================================================== +CURRENT BROWSER CONTENT: +------------------ + + + + +OpenTable logo + +Find your table for any occasion + +Sep 28, 2022 +7:00 PM +2 people + + +It looks like you're in Peninsula. Not correct? + + +------------------ +OBJECTIVE: Make a reservation for 4 for dinner at Dorsia in New York City at 8pm +CURRENT URL: https://www.opentable.com/ +YOUR COMMAND: +TYPESUBMIT 12 "dorsia new york city" +================================================== + +The current browser content, objective, and current URL follow. Reply with your next command to the browser. + +CURRENT BROWSER CONTENT: +------------------ +{browser_content} +------------------ + +OBJECTIVE: {objective} +CURRENT URL: {url} +PREVIOUS COMMAND: {previous_command} +YOUR COMMAND: +""" +PROMPT = Prompt( + input_variables=["browser_content", "url", "previous_command", "objective"], + template=_PROMPT_TEMPLATE, +) diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py new file mode 100644 index 00000000..d2701f8b --- /dev/null +++ b/tests/unit_tests/chains/test_natbot.py @@ -0,0 +1,39 @@ +"""Test functionality related to natbot.""" + +from typing import List, Optional + +from langchain.chains.natbot.base import NatBotChain +from langchain.llms.base import LLM + + +class FakeLLM(LLM): + """Fake LLM wrapper for testing purposes.""" + + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Return `foo` if longer than 10000 words, else `bar`.""" + if len(prompt) > 10000: + return "foo" + else: + return "bar" + + +def test_proper_inputs() -> None: + """Test that natbot shortens inputs correctly.""" + nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing") + url = "foo" * 10000 + browser_content = "foo" * 10000 + output = nat_bot_chain.run(url, browser_content) + assert output == "bar" + + +def test_variable_key_naming() -> None: + """Test that natbot handles variable key naming correctly.""" + nat_bot_chain = NatBotChain( + llm=FakeLLM(), + objective="testing", + input_url_key="u", + input_browser_content_key="b", + output_key="c", + ) + output = nat_bot_chain.run("foo", "foo") + assert output == "bar" diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index 60e3d871..f9d6387b 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -12,7 +12,7 @@ class FakeLLM(LLM): self._queries = queries def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: - """Return `foo` if no stop words, otherwise `bar`.""" + """First try to lookup in queries, else return 'foo' or 'bar'.""" if self._queries is not None: return self._queries[prompt] if stop is None: