diff --git a/examples/natbot.py b/examples/natbot.py
new file mode 100644
index 00000000..9f380bf7
--- /dev/null
+++ b/examples/natbot.py
@@ -0,0 +1,88 @@
+"""Run NatBot."""
+import time
+
+from langchain.chains.natbot.base import NatBotChain
+from langchain.chains.natbot.crawler import Crawler # type: ignore
+
+
+def run_cmd(cmd: str, _crawler: Crawler) -> None:
+ """Run command."""
+ cmd = cmd.split("\n")[0]
+
+ if cmd.startswith("SCROLL UP"):
+ _crawler.scroll("up")
+ elif cmd.startswith("SCROLL DOWN"):
+ _crawler.scroll("down")
+ elif cmd.startswith("CLICK"):
+ commasplit = cmd.split(",")
+ id = commasplit[0].split(" ")[1]
+ _crawler.click(id)
+ elif cmd.startswith("TYPE"):
+ spacesplit = cmd.split(" ")
+ id = spacesplit[1]
+ text_pieces = spacesplit[2:]
+ text = " ".join(text_pieces)
+ # Strip leading and trailing double quotes
+ text = text[1:-1]
+
+ if cmd.startswith("TYPESUBMIT"):
+ text += "\n"
+ _crawler.type(id, text)
+
+ time.sleep(2)
+
+
+if __name__ == "__main__":
+
+ objective = "Make a reservation for 2 at 7pm at bistro vida in menlo park"
+ print("\nWelcome to natbot! What is your objective?")
+ i = input()
+ if len(i) > 0:
+ objective = i
+ quiet = False
+ nat_bot_chain = NatBotChain.from_default(objective)
+ _crawler = Crawler()
+ _crawler.go_to_page("google.com")
+ try:
+ while True:
+ browser_content = "\n".join(_crawler.crawl())
+ llm_command = nat_bot_chain.run(_crawler.page.url, browser_content)
+ if not quiet:
+ print("URL: " + _crawler.page.url)
+ print("Objective: " + objective)
+ print("----------------\n" + browser_content + "\n----------------\n")
+ if len(llm_command) > 0:
+ print("Suggested command: " + llm_command)
+
+ command = input()
+ if command == "r" or command == "":
+ run_cmd(llm_command, _crawler)
+ elif command == "g":
+ url = input("URL:")
+ _crawler.go_to_page(url)
+ elif command == "u":
+ _crawler.scroll("up")
+ time.sleep(1)
+ elif command == "d":
+ _crawler.scroll("down")
+ time.sleep(1)
+ elif command == "c":
+ id = input("id:")
+ _crawler.click(id)
+ time.sleep(1)
+ elif command == "t":
+ id = input("id:")
+ text = input("text:")
+ _crawler.type(id, text)
+ time.sleep(1)
+ elif command == "o":
+ objective = input("Objective:")
+ else:
+ print(
+ "(g) to visit url\n(u) scroll up\n(d) scroll down\n(c) to click"
+ "\n(t) to type\n(h) to view commands again"
+ "\n(r/enter) to run suggested command\n(o) change objective"
+ )
+ except KeyboardInterrupt:
+ print("\n[!] Ctrl+C detected, exiting gracefully.")
+ exit(0)
diff --git a/langchain/chains/natbot/__init__.py b/langchain/chains/natbot/__init__.py
new file mode 100644
index 00000000..45a2231a
--- /dev/null
+++ b/langchain/chains/natbot/__init__.py
@@ -0,0 +1,4 @@
+"""Implement a GPT-3 driven browser.
+
+Heavily influenced from https://github.com/nat/natbot
+"""
diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py
new file mode 100644
index 00000000..5c2435f3
--- /dev/null
+++ b/langchain/chains/natbot/base.py
@@ -0,0 +1,65 @@
+"""Implement a GPT-3 driven browser."""
+from typing import Dict, List
+
+from pydantic import BaseModel, Extra
+
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.chains.natbot.prompt import PROMPT
+from langchain.llms.base import LLM
+from langchain.llms.openai import OpenAI
+
+
+class NatBotChain(Chain, BaseModel):
+ """Implement a GPT-3 driven browser."""
+
+ llm: LLM
+ objective: str
+ input_url_key: str = "url"
+ input_browser_content_key: str = "browser_content"
+ previous_command: str = ""
+ output_key: str = "command"
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def from_default(cls, objective: str) -> "NatBotChain":
+ """Load with default LLM."""
+ llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
+ return cls(llm=llm, objective=objective)
+
+ @property
+ def input_keys(self) -> List[str]:
+ """Expect url and browser content."""
+ return [self.input_url_key, self.input_browser_content_key]
+
+ @property
+ def output_keys(self) -> List[str]:
+ """Return command."""
+ return [self.output_key]
+
+ def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+ llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
+ url = inputs[self.input_url_key]
+ browser_content = inputs[self.input_browser_content_key]
+ llm_cmd = llm_executor.predict(
+ objective=self.objective,
+ url=url[:100],
+ previous_command=self.previous_command,
+ browser_content=browser_content[:4500],
+ )
+ llm_cmd = llm_cmd.strip()
+ self.previous_command = llm_cmd
+ return {self.output_key: llm_cmd}
+
+ def run(self, url: str, browser_content: str) -> str:
+ """More user-friendly interface for interfacing with natbot."""
+ _inputs = {
+ self.input_url_key: url,
+ self.input_browser_content_key: browser_content,
+ }
+ return self(_inputs)[self.output_key]
diff --git a/langchain/chains/natbot/crawler.py b/langchain/chains/natbot/crawler.py
new file mode 100644
index 00000000..341b890b
--- /dev/null
+++ b/langchain/chains/natbot/crawler.py
@@ -0,0 +1,408 @@
+# flake8: noqa
+# type: ignore
+import time
+from sys import platform
+
+black_listed_elements = {
+ "html",
+ "head",
+ "title",
+ "meta",
+ "iframe",
+ "body",
+ "script",
+ "style",
+ "path",
+ "svg",
+ "br",
+ "::marker",
+}
+
+
+class Crawler:
+ def __init__(self):
+ try:
+ from playwright.sync_api import sync_playwright
+ except ImportError:
+ raise ValueError(
+ "Could not import playwright python package. "
+ "Please it install it with `pip install playwright`."
+ )
+ self.browser = (
+ sync_playwright()
+ .start()
+ .chromium.launch(
+ headless=False,
+ )
+ )
+
+ self.page = self.browser.new_page()
+ self.page.set_viewport_size({"width": 1280, "height": 1080})
+
+ def go_to_page(self, url):
+ self.page.goto(url=url if "://" in url else "http://" + url)
+ self.client = self.page.context.new_cdp_session(self.page)
+ self.page_element_buffer = {}
+
+ def scroll(self, direction):
+ if direction == "up":
+ self.page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
+ )
+ elif direction == "down":
+ self.page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
+ )
+
+ def click(self, id):
+ # Inject javascript into the page which removes the target= attribute from all links
+ js = """
+ links = document.getElementsByTagName("a");
+ for (var i = 0; i < links.length; i++) {
+ links[i].removeAttribute("target");
+ }
+ """
+ self.page.evaluate(js)
+
+ element = self.page_element_buffer.get(int(id))
+ if element:
+ x = element.get("center_x")
+ y = element.get("center_y")
+
+ self.page.mouse.click(x, y)
+ else:
+ print("Could not find element")
+
+ def type(self, id, text):
+ self.click(id)
+ self.page.keyboard.type(text)
+
+ def enter(self):
+ self.page.keyboard.press("Enter")
+
+ def crawl(self):
+ page = self.page
+ page_element_buffer = self.page_element_buffer
+ start = time.time()
+
+ page_state_as_text = []
+
+ device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+ if platform == "darwin" and device_pixel_ratio == 1: # lies
+ device_pixel_ratio = 2
+
+ win_scroll_x = page.evaluate("window.scrollX")
+ win_scroll_y = page.evaluate("window.scrollY")
+ win_upper_bound = page.evaluate("window.pageYOffset")
+ win_left_bound = page.evaluate("window.pageXOffset")
+ win_width = page.evaluate("window.screen.width")
+ win_height = page.evaluate("window.screen.height")
+ win_right_bound = win_left_bound + win_width
+ win_lower_bound = win_upper_bound + win_height
+ document_offset_height = page.evaluate("document.body.offsetHeight")
+ document_scroll_height = page.evaluate("document.body.scrollHeight")
+
+ # percentage_progress_start = (win_upper_bound / document_scroll_height) * 100
+ # percentage_progress_end = (
+ # (win_height + win_upper_bound) / document_scroll_height
+ # ) * 100
+ percentage_progress_start = 1
+ percentage_progress_end = 2
+
+ page_state_as_text.append(
+ {
+ "x": 0,
+ "y": 0,
+ "text": "[scrollbar {:0.2f}-{:0.2f}%]".format(
+ round(percentage_progress_start, 2), round(percentage_progress_end)
+ ),
+ }
+ )
+
+ tree = self.client.send(
+ "DOMSnapshot.captureSnapshot",
+ {"computedStyles": [], "includeDOMRects": True, "includePaintOrder": True},
+ )
+ strings = tree["strings"]
+ document = tree["documents"][0]
+ nodes = document["nodes"]
+ backend_node_id = nodes["backendNodeId"]
+ attributes = nodes["attributes"]
+ node_value = nodes["nodeValue"]
+ parent = nodes["parentIndex"]
+ node_types = nodes["nodeType"]
+ node_names = nodes["nodeName"]
+ is_clickable = set(nodes["isClickable"]["index"])
+
+ text_value = nodes["textValue"]
+ text_value_index = text_value["index"]
+ text_value_values = text_value["value"]
+
+ input_value = nodes["inputValue"]
+ input_value_index = input_value["index"]
+ input_value_values = input_value["value"]
+
+ input_checked = nodes["inputChecked"]
+ layout = document["layout"]
+ layout_node_index = layout["nodeIndex"]
+ bounds = layout["bounds"]
+
+ cursor = 0
+ html_elements_text = []
+
+ child_nodes = {}
+ elements_in_view_port = []
+
+ anchor_ancestry = {"-1": (False, None)}
+ button_ancestry = {"-1": (False, None)}
+
+ def convert_name(node_name, has_click_handler):
+ if node_name == "a":
+ return "link"
+ if node_name == "input":
+ return "input"
+ if node_name == "img":
+ return "img"
+ if (
+ node_name == "button" or has_click_handler
+ ): # found pages that needed this quirk
+ return "button"
+ else:
+ return "text"
+
+ def find_attributes(attributes, keys):
+ values = {}
+
+ for [key_index, value_index] in zip(*(iter(attributes),) * 2):
+ if value_index < 0:
+ continue
+ key = strings[key_index]
+ value = strings[value_index]
+
+ if key in keys:
+ values[key] = value
+ keys.remove(key)
+
+ if not keys:
+ return values
+
+ return values
+
+ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
+ parent_id_str = str(parent_id)
+ if not parent_id_str in hash_tree:
+ parent_name = strings[node_names[parent_id]].lower()
+ grand_parent_id = parent[parent_id]
+
+ add_to_hash_tree(
+ hash_tree, tag, parent_id, parent_name, grand_parent_id
+ )
+
+ is_parent_desc_anchor, anchor_id = hash_tree[parent_id_str]
+
+ # even if the anchor is nested in another anchor, we set the "root" for all descendants to be ::Self
+ if node_name == tag:
+ value = (True, node_id)
+ elif (
+ is_parent_desc_anchor
+ ): # reuse the parent's anchor_id (which could be much higher in the tree)
+ value = (True, anchor_id)
+ else:
+ value = (
+ False,
+ None,
+ ) # not a descendant of an anchor, most likely it will become text, an interactive element or discarded
+
+ hash_tree[str(node_id)] = value
+
+ return value
+
+ for index, node_name_index in enumerate(node_names):
+ node_parent = parent[index]
+ node_name = strings[node_name_index].lower()
+
+ is_ancestor_of_anchor, anchor_id = add_to_hash_tree(
+ anchor_ancestry, "a", index, node_name, node_parent
+ )
+
+ is_ancestor_of_button, button_id = add_to_hash_tree(
+ button_ancestry, "button", index, node_name, node_parent
+ )
+
+ try:
+ cursor = layout_node_index.index(
+ index
+ ) # todo replace this with proper cursoring, ignoring the fact this is O(n^2) for the moment
+ except:
+ continue
+
+ if node_name in black_listed_elements:
+ continue
+
+ [x, y, width, height] = bounds[cursor]
+ x /= device_pixel_ratio
+ y /= device_pixel_ratio
+ width /= device_pixel_ratio
+ height /= device_pixel_ratio
+
+ elem_left_bound = x
+ elem_top_bound = y
+ elem_right_bound = x + width
+ elem_lower_bound = y + height
+
+ partially_is_in_viewport = (
+ elem_left_bound < win_right_bound
+ and elem_right_bound >= win_left_bound
+ and elem_top_bound < win_lower_bound
+ and elem_lower_bound >= win_upper_bound
+ )
+
+ if not partially_is_in_viewport:
+ continue
+
+ meta_data = []
+
+ # inefficient to grab the same set of keys for kinds of objects but its fine for now
+ element_attributes = find_attributes(
+ attributes[index], ["type", "placeholder", "aria-label", "title", "alt"]
+ )
+
+ ancestor_exception = is_ancestor_of_anchor or is_ancestor_of_button
+ ancestor_node_key = (
+ None
+ if not ancestor_exception
+ else str(anchor_id)
+ if is_ancestor_of_anchor
+ else str(button_id)
+ )
+ ancestor_node = (
+ None
+ if not ancestor_exception
+ else child_nodes.setdefault(str(ancestor_node_key), [])
+ )
+
+ if node_name == "#text" and ancestor_exception:
+ text = strings[node_value[index]]
+ if text == "|" or text == "•":
+ continue
+ ancestor_node.append({"type": "type", "value": text})
+ else:
+ if (
+ node_name == "input" and element_attributes.get("type") == "submit"
+ ) or node_name == "button":
+ node_name = "button"
+ element_attributes.pop(
+ "type", None
+ ) # prevent [button ... (button)..]
+
+ for key in element_attributes:
+ if ancestor_exception:
+ ancestor_node.append(
+ {
+ "type": "attribute",
+ "key": key,
+ "value": element_attributes[key],
+ }
+ )
+ else:
+ meta_data.append(element_attributes[key])
+
+ element_node_value = None
+
+ if node_value[index] >= 0:
+ element_node_value = strings[node_value[index]]
+ if (
+ element_node_value == "|"
+ ): # commonly used as a seperator, does not add much context - lets save ourselves some token space
+ continue
+ elif (
+ node_name == "input"
+ and index in input_value_index
+ and element_node_value is None
+ ):
+ node_input_text_index = input_value_index.index(index)
+ text_index = input_value_values[node_input_text_index]
+ if node_input_text_index >= 0 and text_index >= 0:
+ element_node_value = strings[text_index]
+
+ # remove redudant elements
+ if ancestor_exception and (node_name != "a" and node_name != "button"):
+ continue
+
+ elements_in_view_port.append(
+ {
+ "node_index": str(index),
+ "backend_node_id": backend_node_id[index],
+ "node_name": node_name,
+ "node_value": element_node_value,
+ "node_meta": meta_data,
+ "is_clickable": index in is_clickable,
+ "origin_x": int(x),
+ "origin_y": int(y),
+ "center_x": int(x + (width / 2)),
+ "center_y": int(y + (height / 2)),
+ }
+ )
+
+ # lets filter further to remove anything that does not hold any text nor has click handlers + merge text from leaf#text nodes with the parent
+ elements_of_interest = []
+ id_counter = 0
+
+ for element in elements_in_view_port:
+ node_index = element.get("node_index")
+ node_name = element.get("node_name")
+ node_value = element.get("node_value")
+ is_clickable = element.get("is_clickable")
+ origin_x = element.get("origin_x")
+ origin_y = element.get("origin_y")
+ center_x = element.get("center_x")
+ center_y = element.get("center_y")
+ meta_data = element.get("node_meta")
+
+ inner_text = f"{node_value} " if node_value else ""
+ meta = ""
+
+ if node_index in child_nodes:
+ for child in child_nodes.get(node_index):
+ entry_type = child.get("type")
+ entry_value = child.get("value")
+
+ if entry_type == "attribute":
+ entry_key = child.get("key")
+ meta_data.append(f'{entry_key}="{entry_value}"')
+ else:
+ inner_text += f"{entry_value} "
+
+ if meta_data:
+ meta_string = " ".join(meta_data)
+ meta = f" {meta_string}"
+
+ if inner_text != "":
+ inner_text = f"{inner_text.strip()}"
+
+ converted_node_name = convert_name(node_name, is_clickable)
+
+ # not very elegant, more like a placeholder
+ if (
+ (converted_node_name != "button" or meta == "")
+ and converted_node_name != "link"
+ and converted_node_name != "input"
+ and converted_node_name != "img"
+ and converted_node_name != "textarea"
+ ) and inner_text.strip() == "":
+ continue
+
+ page_element_buffer[id_counter] = element
+
+ if inner_text != "":
+ elements_of_interest.append(
+ f"""<{converted_node_name} id={id_counter}{meta}>{inner_text}{converted_node_name}>"""
+ )
+ else:
+ elements_of_interest.append(
+ f"""<{converted_node_name} id={id_counter}{meta}/>"""
+ )
+ id_counter += 1
+
+ print("Parsing time: {:0.2f} seconds".format(time.time() - start))
+ return elements_of_interest
diff --git a/langchain/chains/natbot/prompt.py b/langchain/chains/natbot/prompt.py
new file mode 100644
index 00000000..390a532d
--- /dev/null
+++ b/langchain/chains/natbot/prompt.py
@@ -0,0 +1,144 @@
+# flake8: noqa
+from langchain.prompt import Prompt
+
+_PROMPT_TEMPLATE = """
+You are an agent controlling a browser. You are given:
+
+ (1) an objective that you are trying to achieve
+ (2) the URL of your current web page
+ (3) a simplified text description of what's visible in the browser window (more on that below)
+
+You can issue these commands:
+ SCROLL UP - scroll up one page
+ SCROLL DOWN - scroll down one page
+ CLICK X - click on a given element. You can only click on links, buttons, and inputs!
+ TYPE X "TEXT" - type the specified text into the input with id X
+ TYPESUBMIT X "TEXT" - same as TYPE above, except then it presses ENTER to submit the form
+
+The format of the browser content is highly simplified; all formatting elements are stripped.
+Interactive elements such as links, inputs, buttons are represented like this:
+
+ text
+
+ text
+
+Images are rendered as their alt text like this:
+
+
+
+Based on your given objective, issue whatever command you believe will get you closest to achieving your goal.
+You always start on Google; you should submit a search query to Google that will take you to the best page for
+achieving your objective. And then interact with that page to achieve your objective.
+
+If you find yourself on Google and there are no search results displayed yet, you should probably issue a command
+like "TYPESUBMIT 7 "search query"" to get to a more useful page.
+
+Then, if you find yourself on a Google search results page, you might issue the command "CLICK 24" to click
+on the first link in the search results. (If your previous command was a TYPESUBMIT your next command should
+probably be a CLICK.)
+
+Don't try to interact with elements that you can't see.
+
+Here are some examples:
+
+EXAMPLE 1:
+==================================================
+CURRENT BROWSER CONTENT:
+------------------
+About
+Store
+Gmail
+Images
+(Google apps)
+Sign in
+
+
+
+
+
+Advertising
+Business
+How Search works
+Carbon neutral since 2007
+Privacy
+Terms
+Settings
+------------------
+OBJECTIVE: Find a 2 bedroom house for sale in Anchorage AK for under $750k
+CURRENT URL: https://www.google.com/
+YOUR COMMAND:
+TYPESUBMIT 8 "anchorage redfin"
+==================================================
+
+EXAMPLE 2:
+==================================================
+CURRENT BROWSER CONTENT:
+------------------
+About
+Store
+Gmail
+Images
+(Google apps)
+Sign in
+
+
+
+
+
+Advertising
+Business
+How Search works
+Carbon neutral since 2007
+Privacy
+Terms
+Settings
+------------------
+OBJECTIVE: Make a reservation for 4 at Dorsia at 8pm
+CURRENT URL: https://www.google.com/
+YOUR COMMAND:
+TYPESUBMIT 8 "dorsia nyc opentable"
+==================================================
+
+EXAMPLE 3:
+==================================================
+CURRENT BROWSER CONTENT:
+------------------
+
+
+
+
+OpenTable logo
+
+Find your table for any occasion
+
+Sep 28, 2022
+7:00 PM
+2 people
+
+
+It looks like you're in Peninsula. Not correct?
+
+
+------------------
+OBJECTIVE: Make a reservation for 4 for dinner at Dorsia in New York City at 8pm
+CURRENT URL: https://www.opentable.com/
+YOUR COMMAND:
+TYPESUBMIT 12 "dorsia new york city"
+==================================================
+
+The current browser content, objective, and current URL follow. Reply with your next command to the browser.
+
+CURRENT BROWSER CONTENT:
+------------------
+{browser_content}
+------------------
+
+OBJECTIVE: {objective}
+CURRENT URL: {url}
+PREVIOUS COMMAND: {previous_command}
+YOUR COMMAND:
+"""
+PROMPT = Prompt(
+ input_variables=["browser_content", "url", "previous_command", "objective"],
+ template=_PROMPT_TEMPLATE,
+)
diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py
new file mode 100644
index 00000000..d2701f8b
--- /dev/null
+++ b/tests/unit_tests/chains/test_natbot.py
@@ -0,0 +1,39 @@
+"""Test functionality related to natbot."""
+
+from typing import List, Optional
+
+from langchain.chains.natbot.base import NatBotChain
+from langchain.llms.base import LLM
+
+
+class FakeLLM(LLM):
+ """Fake LLM wrapper for testing purposes."""
+
+ def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+ """Return `foo` if longer than 10000 words, else `bar`."""
+ if len(prompt) > 10000:
+ return "foo"
+ else:
+ return "bar"
+
+
+def test_proper_inputs() -> None:
+ """Test that natbot shortens inputs correctly."""
+ nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
+ url = "foo" * 10000
+ browser_content = "foo" * 10000
+ output = nat_bot_chain.run(url, browser_content)
+ assert output == "bar"
+
+
+def test_variable_key_naming() -> None:
+ """Test that natbot handles variable key naming correctly."""
+ nat_bot_chain = NatBotChain(
+ llm=FakeLLM(),
+ objective="testing",
+ input_url_key="u",
+ input_browser_content_key="b",
+ output_key="c",
+ )
+ output = nat_bot_chain.run("foo", "foo")
+ assert output == "bar"
diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py
index 60e3d871..f9d6387b 100644
--- a/tests/unit_tests/llms/fake_llm.py
+++ b/tests/unit_tests/llms/fake_llm.py
@@ -12,7 +12,7 @@ class FakeLLM(LLM):
self._queries = queries
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
- """Return `foo` if no stop words, otherwise `bar`."""
+ """First try to lookup in queries, else return 'foo' or 'bar'."""
if self._queries is not None:
return self._queries[prompt]
if stop is None: