From 18aeb720126a68201c7e3b5a617139c27c779496 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 24 Oct 2022 14:51:15 -0700 Subject: [PATCH] initial commit --- .flake8 | 11 ++ .github/workflows/lint.yml | 23 +++ .github/workflows/test.yml | 23 +++ .gitignore | 130 ++++++++++++++++ MANIFEST.in | 2 + Makefile | 17 +++ README.md | 80 ++++++++++ docs/Makefile | 21 +++ docs/conf.py | 65 ++++++++ docs/index.rst | 10 ++ docs/make.bat | 35 +++++ docs/modules/chains.rst | 7 + docs/modules/llms.rst | 6 + docs/modules/prompt.rst | 6 + docs/requirements.txt | 4 + examples/llm_math.ipynb | 59 ++++++++ examples/self_ask_with_search.ipynb | 74 +++++++++ examples/simple_prompts.ipynb | 64 ++++++++ langchain/VERSION | 1 + langchain/__init__.py | 27 ++++ langchain/chains/__init__.py | 14 ++ langchain/chains/base.py | 41 +++++ langchain/chains/llm.py | 46 ++++++ langchain/chains/llm_math/__init__.py | 4 + langchain/chains/llm_math/base.py | 57 +++++++ langchain/chains/llm_math/prompt.py | 38 +++++ langchain/chains/python.py | 40 +++++ .../chains/self_ask_with_search/__init__.py | 4 + langchain/chains/self_ask_with_search/base.py | 142 ++++++++++++++++++ .../chains/self_ask_with_search/prompt.py | 44 ++++++ langchain/chains/serpapi.py | 99 ++++++++++++ langchain/formatting.py | 32 ++++ langchain/llms/__init__.py | 5 + langchain/llms/base.py | 11 ++ langchain/llms/cohere.py | 72 +++++++++ langchain/llms/openai.py | 65 ++++++++ langchain/prompt.py | 47 ++++++ pyproject.toml | 7 + readthedocs.yml | 10 ++ requirements.txt | 9 ++ setup.py | 23 +++ test_requirements.txt | 3 + tests/__init__.py | 1 + tests/integration_tests/__init__.py | 1 + tests/integration_tests/chains/__init__.py | 1 + .../chains/test_self_ask_with_search.py | 18 +++ .../integration_tests/chains/test_serpapi.py | 9 ++ tests/integration_tests/llms/__init__.py | 1 + tests/integration_tests/llms/test_cohere.py | 10 ++ tests/integration_tests/llms/test_openai.py | 10 ++ tests/unit_tests/__init__.py | 1 + tests/unit_tests/chains/__init__.py | 1 + tests/unit_tests/chains/test_base.py | 50 ++++++ tests/unit_tests/chains/test_llm.py | 36 +++++ tests/unit_tests/chains/test_llm_math.py | 40 +++++ tests/unit_tests/chains/test_python.py | 15 ++ .../data/prompts/prompt_extra_args.json | 5 + .../data/prompts/prompt_missing_args.json | 3 + .../data/prompts/simple_prompt.json | 4 + tests/unit_tests/llms/__init__.py | 1 + tests/unit_tests/llms/fake_llm.py | 21 +++ tests/unit_tests/llms/test_cohere.py | 17 +++ tests/unit_tests/test_formatting.py | 26 ++++ tests/unit_tests/test_prompt.py | 47 ++++++ 64 files changed, 1796 insertions(+) create mode 100644 .flake8 create mode 100644 .github/workflows/lint.yml create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 MANIFEST.in create mode 100644 Makefile create mode 100644 README.md create mode 100644 docs/Makefile create mode 100644 docs/conf.py create mode 100644 docs/index.rst create mode 100644 docs/make.bat create mode 100644 docs/modules/chains.rst create mode 100644 docs/modules/llms.rst create mode 100644 docs/modules/prompt.rst create mode 100644 docs/requirements.txt create mode 100644 examples/llm_math.ipynb create mode 100644 examples/self_ask_with_search.ipynb create mode 100644 examples/simple_prompts.ipynb create mode 100644 langchain/VERSION create mode 100644 langchain/__init__.py create mode 100644 langchain/chains/__init__.py create mode 100644 langchain/chains/base.py create mode 100644 langchain/chains/llm.py create mode 100644 
langchain/chains/llm_math/__init__.py create mode 100644 langchain/chains/llm_math/base.py create mode 100644 langchain/chains/llm_math/prompt.py create mode 100644 langchain/chains/python.py create mode 100644 langchain/chains/self_ask_with_search/__init__.py create mode 100644 langchain/chains/self_ask_with_search/base.py create mode 100644 langchain/chains/self_ask_with_search/prompt.py create mode 100644 langchain/chains/serpapi.py create mode 100644 langchain/formatting.py create mode 100644 langchain/llms/__init__.py create mode 100644 langchain/llms/base.py create mode 100644 langchain/llms/cohere.py create mode 100644 langchain/llms/openai.py create mode 100644 langchain/prompt.py create mode 100644 pyproject.toml create mode 100644 readthedocs.yml create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 test_requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/integration_tests/__init__.py create mode 100644 tests/integration_tests/chains/__init__.py create mode 100644 tests/integration_tests/chains/test_self_ask_with_search.py create mode 100644 tests/integration_tests/chains/test_serpapi.py create mode 100644 tests/integration_tests/llms/__init__.py create mode 100644 tests/integration_tests/llms/test_cohere.py create mode 100644 tests/integration_tests/llms/test_openai.py create mode 100644 tests/unit_tests/__init__.py create mode 100644 tests/unit_tests/chains/__init__.py create mode 100644 tests/unit_tests/chains/test_base.py create mode 100644 tests/unit_tests/chains/test_llm.py create mode 100644 tests/unit_tests/chains/test_llm_math.py create mode 100644 tests/unit_tests/chains/test_python.py create mode 100644 tests/unit_tests/data/prompts/prompt_extra_args.json create mode 100644 tests/unit_tests/data/prompts/prompt_missing_args.json create mode 100644 tests/unit_tests/data/prompts/simple_prompt.json create mode 100644 tests/unit_tests/llms/__init__.py create mode 100644 tests/unit_tests/llms/fake_llm.py create mode 100644 tests/unit_tests/llms/test_cohere.py create mode 100644 tests/unit_tests/test_formatting.py create mode 100644 tests/unit_tests/test_prompt.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..64a9cd4c --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +exclude = + .venv + __pycache__ + notebooks +# Recommend matching the black line length (default 88), +# rather than using the flake8 default of 79: +max-line-length = 88 +extend-ignore = + # See https://github.com/PyCQA/pycodestyle/issues/373 + E203, diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..34261d7b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,23 @@ +name: lint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.7"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Analysing the code with our lint + run: | + make lint diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..4492af6f --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: test + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.7"] + steps: + - uses: actions/checkout@v3 + - name: 
Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r test_requirements.txt + - name: Run unit tests + run: | + make tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..eb884bc5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,130 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..e322490e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include langchain/VERSION +include LICENSE diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..091c8a67 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +.PHONY: format lint tests integration_tests + +format: + black . + isort . + +lint: + mypy . + black . --check + isort . --check + flake8 . 
+ +tests: + pytest tests/unit_tests + +integration_tests: + pytest tests/integration_tests diff --git a/README.md b/README.md new file mode 100644 index 00000000..59bad6d4 --- /dev/null +++ b/README.md @@ -0,0 +1,80 @@ +# 🦜️🔗 LangChain + +⚡ Building applications with LLMs through composability ⚡ + +[![lint](https://github.com/hwchase17/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [![test](https://github.com/hwchase17/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + + + +## Quick Install + +`pip install langchain` + +## 🤔 What is this? + +Large language models (LLMs) are emerging as a transformative technology, enabling +developers to build applications that they previously could not. +But using these LLMs in isolation is often not enough to +create a truly powerful app - the real power comes when you are able to +combine them with other sources of computation or knowledge. + +This library is aimed at assisting in the development of those types of applications. +It aims to create: +1. a comprehensive collection of pieces you would ever want to combine +2. a flexible interface for combining pieces into a single comprehensive "chain" +3. a schema for easily saving and sharing those chains + +## 🚀 What can I do with this? + +This project was largely inspired by a few projects seen on Twitter for which we thought it would make sense to have more explicit tooling. A lot of the initial functionality was done in an attempt to recreate those. Those are: + +**[Self-ask-with-search](https://ofir.io/self-ask.pdf)** + +To recreate this paper, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb). + +``` +from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain + +llm = OpenAI(temperature=0) +search = SerpAPIChain() + +self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search) + +self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open champion?") +``` + +**[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)** + +To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb). + +``` +from langchain import OpenAI, LLMMathChain + +llm = OpenAI(temperature=0) +llm_math = LLMMathChain(llm=llm) + +llm_math.run("How many of the integers between 0 and 99 inclusive are divisible by 8?") +``` + +**Generic Prompting** + +You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb). + +``` +from langchain import Prompt, OpenAI, LLMChain + +template = """Question: {question} + +Answer: Let's think step by step.""" +prompt = Prompt(template=template, input_variables=["question"]) +llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0)) + +question = "What NFL team won the Super Bowl in the year Justin Bieber was born?" 
+ +llm_chain.predict(question=question) +``` + +## 📖 Documentation + +The above examples are probably the most user friendly documentation that exists, +but full API docs can be found [here](https://langchain.readthedocs.io/en/latest/?). diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..c2f2a664 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,21 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SPHINXAUTOBUILD ?= sphinx-autobuild +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..97df7697 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,65 @@ +"""Configuration file for the Sphinx documentation builder.""" +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + +import langchain + +# -- Project information ----------------------------------------------------- + +project = "LangChain" +copyright = "2022, Harrison Chase" +author = "Harrison Chase" + +version = langchain.__version__ +release = langchain.__version__ + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autodoc.typehints", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", +] + +# autodoc_typehints = "signature" +autodoc_typehints = "description" + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" +# html_theme = "sphinx_typlog_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". 
+html_static_path = ["_static"] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..c6af00fd --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,10 @@ +Welcome to LangChain +========================== + +.. toctree:: + :maxdepth: 2 + :caption: User API + + modules/prompt + modules/llms + modules/chains diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..2119f510 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/modules/chains.rst b/docs/modules/chains.rst new file mode 100644 index 00000000..b3bc7781 --- /dev/null +++ b/docs/modules/chains.rst @@ -0,0 +1,7 @@ +:mod:`langchain.chains` +======================= + +.. automodule:: langchain.chains + :members: + :undoc-members: + diff --git a/docs/modules/llms.rst b/docs/modules/llms.rst new file mode 100644 index 00000000..bbb613ea --- /dev/null +++ b/docs/modules/llms.rst @@ -0,0 +1,6 @@ +:mod:`langchain.llms` +======================= + +.. automodule:: langchain.llms + :members: + :undoc-members: diff --git a/docs/modules/prompt.rst b/docs/modules/prompt.rst new file mode 100644 index 00000000..d6ae1784 --- /dev/null +++ b/docs/modules/prompt.rst @@ -0,0 +1,6 @@ +:mod:`langchain.prompt` +======================= + +.. 
automodule:: langchain.prompt + :members: + :undoc-members: diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..043d6f82 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx==4.5.0 +sphinx-autobuild==2021.3.14 +sphinx_rtd_theme==1.0.0 +sphinx-typlog-theme==0.8.0 diff --git a/examples/llm_math.ipynb b/examples/llm_math.ipynb new file mode 100644 index 00000000..f07479f5 --- /dev/null +++ b/examples/llm_math.ipynb @@ -0,0 +1,59 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "44e9ba31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Answer: 13\\n'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain import OpenAI, LLMMathChain\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "llm_math = LLMMathChain(llm=llm)\n", + "\n", + "llm_math.run(\"How many of the integers between 0 and 99 inclusive are divisible by 8?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f62f0c75", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/self_ask_with_search.ipynb b/examples/self_ask_with_search.ipynb new file mode 100644 index 00000000..b059bfc8 --- /dev/null +++ b/examples/self_ask_with_search.ipynb @@ -0,0 +1,74 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "7e3b513e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "What is the hometown of the reigning men's U.S. Open champion?\n", + "Are follow up questions needed here:\u001b[102m Yes.\n", + "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", + "Intermediate answer: \u001b[106mCarlos Alcaraz\u001b[0m.\u001b[102m\n", + "Follow up: Where is Carlos Alcaraz from?\u001b[0m\n", + "Intermediate answer: \u001b[106mEl Palmar, Murcia, Spain\u001b[0m.\u001b[102m\n", + "So the final answer is: El Palmar, Murcia, Spain\u001b[0m" + ] + }, + { + "data": { + "text/plain": [ + "\"What is the hometown of the reigning men's U.S. Open champion?\\nAre follow up questions needed here: Yes.\\nFollow up: Who is the reigning men's U.S. Open champion?\\nIntermediate answer: Carlos Alcaraz.\\nFollow up: Where is Carlos Alcaraz from?\\nIntermediate answer: El Palmar, Murcia, Spain.\\nSo the final answer is: El Palmar, Murcia, Spain\"" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "search = SerpAPIChain()\n", + "\n", + "self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search)\n", + "\n", + "self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. 
Open champion?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6195fc82", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/simple_prompts.ipynb b/examples/simple_prompts.ipynb new file mode 100644 index 00000000..73d50fe9 --- /dev/null +++ b/examples/simple_prompts.ipynb @@ -0,0 +1,64 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "51a54c4d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' The year Justin Beiber was born was 1994. In 1994, the Dallas Cowboys won the Super Bowl.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain import Prompt, OpenAI, LLMChain\n", + "\n", + "template = \"\"\"Question: {question}\n", + "\n", + "Answer: Let's think step by step.\"\"\"\n", + "prompt = Prompt(template=template, input_variables=[\"question\"])\n", + "llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))\n", + "\n", + "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n", + "\n", + "llm_chain.predict(question=question)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03dd6918", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/VERSION b/langchain/VERSION new file mode 100644 index 00000000..8acdd82b --- /dev/null +++ b/langchain/VERSION @@ -0,0 +1 @@ +0.0.1 diff --git a/langchain/__init__.py b/langchain/__init__.py new file mode 100644 index 00000000..8f3733fe --- /dev/null +++ b/langchain/__init__.py @@ -0,0 +1,27 @@ +"""Main entrypoint into package.""" + +from pathlib import Path + +with open(Path(__file__).absolute().parents[0] / "VERSION") as _f: + __version__ = _f.read().strip() + +from langchain.chains import ( + LLMChain, + LLMMathChain, + PythonChain, + SelfAskWithSearchChain, + SerpAPIChain, +) +from langchain.llms import Cohere, OpenAI +from langchain.prompt import Prompt + +__all__ = [ + "LLMChain", + "LLMMathChain", + "PythonChain", + "SelfAskWithSearchChain", + "SerpAPIChain", + "Cohere", + "OpenAI", + "Prompt", +] diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py new file mode 100644 index 00000000..c77010b9 --- /dev/null +++ b/langchain/chains/__init__.py @@ -0,0 +1,14 @@ +"""Chains are easily reusable components which can be linked together.""" +from langchain.chains.llm import LLMChain +from langchain.chains.llm_math.base import LLMMathChain +from langchain.chains.python import PythonChain +from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain +from langchain.chains.serpapi import SerpAPIChain + +__all__ = [ + 
"LLMChain", + "LLMMathChain", + "PythonChain", + "SelfAskWithSearchChain", + "SerpAPIChain", +] diff --git a/langchain/chains/base.py b/langchain/chains/base.py new file mode 100644 index 00000000..4d200bd0 --- /dev/null +++ b/langchain/chains/base.py @@ -0,0 +1,41 @@ +"""Base interface that all chains should implement.""" +from abc import ABC, abstractmethod +from typing import Any, Dict, List + + +class Chain(ABC): + """Base interface that all chains should implement.""" + + @property + @abstractmethod + def input_keys(self) -> List[str]: + """Input keys this chain expects.""" + + @property + @abstractmethod + def output_keys(self) -> List[str]: + """Output keys this chain expects.""" + + def _validate_inputs(self, inputs: Dict[str, str]) -> None: + """Check that all inputs are present.""" + missing_keys = set(self.input_keys).difference(inputs) + if missing_keys: + raise ValueError(f"Missing some input keys: {missing_keys}") + + def _validate_outputs(self, outputs: Dict[str, str]) -> None: + if set(outputs) != set(self.output_keys): + raise ValueError( + f"Did not get output keys that were expected. " + f"Got: {set(outputs)}. Expected: {set(self.output_keys)}." + ) + + @abstractmethod + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + """Run the logic of this chain and return the output.""" + + def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """Run the logic of this chain and add to output.""" + self._validate_inputs(inputs) + outputs = self._run(inputs) + self._validate_outputs(outputs) + return {**inputs, **outputs} diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py new file mode 100644 index 00000000..7c4f2b44 --- /dev/null +++ b/langchain/chains/llm.py @@ -0,0 +1,46 @@ +"""Chain that just formats a prompt and calls an LLM.""" +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.llms.base import LLM +from langchain.prompt import Prompt + + +class LLMChain(Chain, BaseModel): + """Chain to run queries against LLMs.""" + + prompt: Prompt + llm: LLM + output_key: str = "text" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects.""" + return self.prompt.input_variables + + @property + def output_keys(self) -> List[str]: + """Will always return text key.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: + selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} + prompt = self.prompt.format(**selected_inputs) + + kwargs = {} + if "stop" in inputs: + kwargs["stop"] = inputs["stop"] + response = self.llm(prompt, **kwargs) + return {self.output_key: response} + + def predict(self, **kwargs: Any) -> str: + """More user-friendly interface for interacting with LLMs.""" + return self(kwargs)[self.output_key] diff --git a/langchain/chains/llm_math/__init__.py b/langchain/chains/llm_math/__init__.py new file mode 100644 index 00000000..fa9fd272 --- /dev/null +++ b/langchain/chains/llm_math/__init__.py @@ -0,0 +1,4 @@ +"""Chain that interprets a prompt and executes python code to do math. 
+ +Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py +""" diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py new file mode 100644 index 00000000..944b6e25 --- /dev/null +++ b/langchain/chains/llm_math/base.py @@ -0,0 +1,57 @@ +"""Chain that interprets a prompt and executes python code to do math.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.llm_math.prompt import PROMPT +from langchain.chains.python import PythonChain +from langchain.llms.base import LLM + + +class LLMMathChain(Chain, BaseModel): + """Chain that interprets a prompt and executes python code to do math.""" + + llm: LLM + verbose: bool = False + input_key: str = "question" + output_key: str = "answer" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) + python_executor = PythonChain() + question = inputs[self.input_key] + t = llm_executor.predict(question=question, stop=["```output"]).strip() + if t.startswith("```python"): + code = t[9:-4] + if self.verbose: + print("[DEBUG] evaluating code") + print(code) + output = python_executor.run(code) + answer = "Answer: " + output + elif t.startswith("Answer:"): + answer = t + else: + raise ValueError(f"unknown format from LLM: {t}") + return {self.output_key: answer} + + def run(self, question: str) -> str: + """More user-friendly interface for interfacing with LLM math.""" + return self({self.input_key: question})[self.output_key] diff --git a/langchain/chains/llm_math/prompt.py b/langchain/chains/llm_math/prompt.py new file mode 100644 index 00000000..a5614b3a --- /dev/null +++ b/langchain/chains/llm_math/prompt.py @@ -0,0 +1,38 @@ +# flake8: noqa +from langchain.prompt import Prompt + +_PROMPT_TEMPLATE = """You are GPT-3, and you can't do math. + +You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers. + +So we hooked you up to a Python 3 kernel, and now you can execute code. If anyone gives you a hard math problem, just use this format and we’ll take care of the rest: + +Question: ${{Question with hard calculation.}} +```python +${{Code that prints what you need to know}} +``` +```output +${{Output of your code}} +``` +Answer: ${{Answer}} + +Otherwise, use this simpler format: + +Question: ${{Question without hard calculation}} +Answer: ${{Answer}} + +Begin. + +Question: What is 37593 * 67? + +```python +print(37593 * 67) +``` +```output +2518731 +``` +Answer: 2518731 + +Question: {question}""" + +PROMPT = Prompt(input_variables=["question"], template=_PROMPT_TEMPLATE) diff --git a/langchain/chains/python.py b/langchain/chains/python.py new file mode 100644 index 00000000..2f81113f --- /dev/null +++ b/langchain/chains/python.py @@ -0,0 +1,40 @@ +"""Chain that runs python code. 
+ +Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py +""" +import sys +from io import StringIO +from typing import Dict, List + +from pydantic import BaseModel + +from langchain.chains.base import Chain + + +class PythonChain(Chain, BaseModel): + """Chain to run python code.""" + + input_key: str = "code" + output_key: str = "output" + + @property + def input_keys(self) -> List[str]: + """Expect input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + exec(inputs[self.input_key]) + sys.stdout = old_stdout + output = mystdout.getvalue() + return {self.output_key: output} + + def run(self, code: str) -> str: + """More user-friendly interface for interfacing with python.""" + return self({self.input_key: code})[self.output_key] diff --git a/langchain/chains/self_ask_with_search/__init__.py b/langchain/chains/self_ask_with_search/__init__.py new file mode 100644 index 00000000..70a450ac --- /dev/null +++ b/langchain/chains/self_ask_with_search/__init__.py @@ -0,0 +1,4 @@ +"""Chain that does self ask with search. + +Heavily borrowed from https://github.com/ofirpress/self-ask +""" diff --git a/langchain/chains/self_ask_with_search/base.py b/langchain/chains/self_ask_with_search/base.py new file mode 100644 index 00000000..bd1bb8cf --- /dev/null +++ b/langchain/chains/self_ask_with_search/base.py @@ -0,0 +1,142 @@ +"""Chain that does self ask with search.""" +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.self_ask_with_search.prompt import PROMPT +from langchain.chains.serpapi import SerpAPIChain +from langchain.llms.base import LLM + + +def extract_answer(generated: str) -> str: + """Extract answer from text.""" + if "\n" not in generated: + last_line = generated + else: + last_line = generated.split("\n")[-1] + + if ":" not in last_line: + after_colon = last_line + else: + after_colon = generated.split(":")[-1] + + if " " == after_colon[0]: + after_colon = after_colon[1:] + if "." == after_colon[-1]: + after_colon = after_colon[:-1] + + return after_colon + + +def extract_question(generated: str, followup: str) -> str: + """Extract question from text.""" + if "\n" not in generated: + last_line = generated + else: + last_line = generated.split("\n")[-1] + + if followup not in last_line: + print("we probably should never get here..." + generated) + + if ":" not in last_line: + after_colon = last_line + else: + after_colon = generated.split(":")[-1] + + if " " == after_colon[0]: + after_colon = after_colon[1:] + if "?" != after_colon[-1]: + print("we probably should never get here..." 
+ generated) + + return after_colon + + +def get_last_line(generated: str) -> str: + """Get the last line in text.""" + if "\n" not in generated: + last_line = generated + else: + last_line = generated.split("\n")[-1] + + return last_line + + +def greenify(_input: str) -> str: + """Add green highlighting to text.""" + return "\x1b[102m" + _input + "\x1b[0m" + + +def yellowfy(_input: str) -> str: + """Add yellow highlighting to text.""" + return "\x1b[106m" + _input + "\x1b[0m" + + +class SelfAskWithSearchChain(Chain, BaseModel): + """Chain that does self ask with search.""" + + llm: LLM + search_chain: SerpAPIChain + input_key: str = "question" + output_key: str = "answer" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: + question = inputs[self.input_key] + llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) + intermediate = "\nIntermediate answer:" + followup = "Follow up:" + finalans = "\nSo the final answer is:" + cur_prompt = f"{question}\nAre follow up questions needed here:" + print(cur_prompt, end="") + ret_text = llm_chain.predict(input=cur_prompt, stop=[intermediate]) + print(greenify(ret_text), end="") + while followup in get_last_line(ret_text): + cur_prompt += ret_text + question = extract_question(ret_text, followup) + external_answer = self.search_chain.search(question) + if external_answer is not None: + cur_prompt += intermediate + " " + external_answer + "." + print( + intermediate + " " + yellowfy(external_answer) + ".", + end="", + ) + ret_text = llm_chain.predict( + input=cur_prompt, stop=["\nIntermediate answer:"] + ) + print(greenify(ret_text), end="") + else: + # We only get here in the very rare case that Google returns no answer. + cur_prompt += intermediate + print(intermediate + " ") + cur_prompt += llm_chain.predict( + input=cur_prompt, stop=["\n" + followup, finalans] + ) + + if finalans not in ret_text: + cur_prompt += finalans + print(finalans, end="") + ret_text = llm_chain.predict(input=cur_prompt, stop=["\n"]) + print(greenify(ret_text), end="") + + return {self.output_key: cur_prompt + ret_text} + + def run(self, question: str) -> str: + """More user-friendly interface for interfacing with self ask with search.""" + return self({self.input_key: question})[self.output_key] diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/chains/self_ask_with_search/prompt.py new file mode 100644 index 00000000..cb52d3c8 --- /dev/null +++ b/langchain/chains/self_ask_with_search/prompt.py @@ -0,0 +1,44 @@ +# flake8: noqa +from langchain.prompt import Prompt + +_DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing? +Are follow up questions needed here: Yes. +Follow up: How old was Muhammad Ali when he died? +Intermediate answer: Muhammad Ali was 74 years old when he died. +Follow up: How old was Alan Turing when he died? +Intermediate answer: Alan Turing was 41 years old when he died. +So the final answer is: Muhammad Ali + +Question: When was the founder of craigslist born? +Are follow up questions needed here: Yes. +Follow up: Who was the founder of craigslist? +Intermediate answer: Craigslist was founded by Craig Newmark. +Follow up: When was Craig Newmark born? 
+Intermediate answer: Craig Newmark was born on December 6, 1952. +So the final answer is: December 6, 1952 + +Question: Who was the maternal grandfather of George Washington? +Are follow up questions needed here: Yes. +Follow up: Who was the mother of George Washington? +Intermediate answer: The mother of George Washington was Mary Ball Washington. +Follow up: Who was the father of Mary Ball Washington? +Intermediate answer: The father of Mary Ball Washington was Joseph Ball. +So the final answer is: Joseph Ball + +Question: Are both the directors of Jaws and Casino Royale from the same country? +Are follow up questions needed here: Yes. +Follow up: Who is the director of Jaws? +Intermediate Answer: The director of Jaws is Steven Spielberg. +Follow up: Where is Steven Spielberg from? +Intermediate Answer: The United States. +Follow up: Who is the director of Casino Royale? +Intermediate Answer: The director of Casino Royale is Martin Campbell. +Follow up: Where is Martin Campbell from? +Intermediate Answer: New Zealand. +So the final answer is: No + +Question: {input}""" +PROMPT = Prompt( + input_variables=["input"], + template=_DEFAULT_TEMPLATE, +) diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py new file mode 100644 index 00000000..2c7c32da --- /dev/null +++ b/langchain/chains/serpapi.py @@ -0,0 +1,99 @@ +"""Chain that calls SerpAPI. + +Heavily borrowed from https://github.com/ofirpress/self-ask +""" +import os +import sys +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.base import Chain + + +class HiddenPrints: + """Context manager to hide prints.""" + + def __enter__(self) -> None: + """Open file to pipe stdout to.""" + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, *_: Any) -> None: + """Close file that stdout was piped to.""" + sys.stdout.close() + sys.stdout = self._original_stdout + + +class SerpAPIChain(Chain, BaseModel): + """Chain that calls SerpAPI.""" + + search_engine: Any + input_key: str = "search_query" + output_key: str = "search_result" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @property + def input_keys(self) -> List[str]: + """Return the singular input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key.""" + return [self.output_key] + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exist in environment.""" + if "SERPAPI_API_KEY" not in os.environ: + raise ValueError( + "Did not find SerpAPI API key, please add an environment variable" + " `SERPAPI_API_KEY` which contains it." + ) + try: + from serpapi import GoogleSearch + + values["search_engine"] = GoogleSearch + except ImportError: + raise ValueError( + "Could not import serpapi python package. " + "Please install it with `pip install google-search-results`." 
+ ) + return values + + def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: + params = { + "api_key": os.environ["SERPAPI_API_KEY"], + "engine": "google", + "q": inputs[self.input_key], + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + with HiddenPrints(): + search = self.search_engine(params) + res = search.get_dict() + + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret = res["answer_box"]["answer"] + elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret = res["answer_box"]["snippet"] + elif ( + "answer_box" in res.keys() + and "snippet_highlighted_words" in res["answer_box"].keys() + ): + toret = res["answer_box"]["snippet_highlighted_words"][0] + elif "snippet" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["snippet"] + else: + toret = None + return {self.output_key: toret} + + def search(self, search_question: str) -> str: + """More user-friendly interface for interfacing with search.""" + return self({self.input_key: search_question})[self.output_key] diff --git a/langchain/formatting.py b/langchain/formatting.py new file mode 100644 index 00000000..61c7c116 --- /dev/null +++ b/langchain/formatting.py @@ -0,0 +1,32 @@ +"""Utilities for formatting strings.""" +from string import Formatter +from typing import Any, Mapping, Sequence, Union + + +class StrictFormatter(Formatter): + """A subclass of formatter that checks for extra keys.""" + + def check_unused_args( + self, + used_args: Sequence[Union[int, str]], + args: Sequence, + kwargs: Mapping[str, Any], + ) -> None: + """Check to see if extra parameters are passed.""" + extra = set(kwargs).difference(used_args) + if extra: + raise KeyError(extra) + + def vformat( + self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] + ) -> str: + """Check that no arguments are provided.""" + if len(args) > 0: + raise ValueError( + "No arguments should be provided, " + "everything should be passed as keyword arguments." 
+ ) + return super().vformat(format_string, args, kwargs) + + +formatter = StrictFormatter() diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py new file mode 100644 index 00000000..2e1720f3 --- /dev/null +++ b/langchain/llms/__init__.py @@ -0,0 +1,5 @@ +"""Wrappers on top of large language models.""" +from langchain.llms.cohere import Cohere +from langchain.llms.openai import OpenAI + +__all__ = ["Cohere", "OpenAI"] diff --git a/langchain/llms/base.py b/langchain/llms/base.py new file mode 100644 index 00000000..56382efd --- /dev/null +++ b/langchain/llms/base.py @@ -0,0 +1,11 @@ +"""Base interface for large language models to expose.""" +from abc import ABC, abstractmethod +from typing import List, Optional + + +class LLM(ABC): + """LLM wrapper should take in a prompt and return a string.""" + + @abstractmethod + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Run the LLM on the given prompt and input.""" diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py new file mode 100644 index 00000000..61504ba6 --- /dev/null +++ b/langchain/llms/cohere.py @@ -0,0 +1,72 @@ +"""Wrapper around Cohere APIs.""" +import os +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.llms.base import LLM + + +def remove_stop_tokens(text: str, stop: List[str]) -> str: + """Remove stop tokens, should they occur at end.""" + for s in stop: + if text.endswith(s): + return text[: -len(s)] + return text + + +class Cohere(BaseModel, LLM): + """Wrapper around Cohere large language models.""" + + client: Any + model: str = "gptd-instruct-tft" + max_tokens: int = 256 + temperature: float = 0.6 + k: int = 0 + p: int = 1 + frequency_penalty: int = 0 + presence_penalty: int = 0 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Validate that api key and python package exist in environment.""" + if "COHERE_API_KEY" not in os.environ: + raise ValueError( + "Did not find Cohere API key, please add an environment variable" + " `COHERE_API_KEY` which contains it." + ) + try: + import cohere + + values["client"] = cohere.Client(os.environ["COHERE_API_KEY"]) + except ImportError: + raise ValueError( + "Could not import cohere python package. " + "Please install it with `pip install cohere`." + ) + return values + + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Call out to Cohere's generate endpoint.""" + response = self.client.generate( + model=self.model, + prompt=prompt, + max_tokens=self.max_tokens, + temperature=self.temperature, + k=self.k, + p=self.p, + frequency_penalty=self.frequency_penalty, + presence_penalty=self.presence_penalty, + stop_sequences=stop, + ) + text = response.generations[0].text + # If stop tokens are provided, Cohere's endpoint returns them. + # In order to make this consistent with other endpoints, we strip them. 
+ if stop is not None: + text = remove_stop_tokens(text, stop) + return text diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py new file mode 100644 index 00000000..70296a9f --- /dev/null +++ b/langchain/llms/openai.py @@ -0,0 +1,65 @@ +"""Wrapper around OpenAI APIs.""" +import os +from typing import Any, Dict, List, Mapping, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.llms.base import LLM + + +class OpenAI(BaseModel, LLM): + """Wrapper around OpenAI large language models.""" + + client: Any + model_name: str = "text-davinci-002" + temperature: float = 0.7 + max_tokens: int = 256 + top_p: int = 1 + frequency_penalty: int = 0 + presence_penalty: int = 0 + n: int = 1 + best_of: int = 1 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exist in environment.""" + if "OPENAI_API_KEY" not in os.environ: + raise ValueError( + "Did not find OpenAI API key, please add an environment variable" + " `OPENAI_API_KEY` which contains it." + ) + try: + import openai + + values["client"] = openai.Completion + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + return values + + @property + def default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "n": self.n, + "best_of": self.best_of, + } + + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Call out to OpenAI's create endpoint.""" + response = self.client.create( + model=self.model_name, prompt=prompt, stop=stop, **self.default_params + ) + return response["choices"][0]["text"] diff --git a/langchain/prompt.py b/langchain/prompt.py new file mode 100644 index 00000000..48ddd759 --- /dev/null +++ b/langchain/prompt.py @@ -0,0 +1,47 @@ +"""Prompt schema definition.""" +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.formatting import formatter + +_FORMATTER_MAPPING = { + "f-string": formatter.format, +} + + +class Prompt(BaseModel): + """Schema to represent a prompt for an LLM.""" + + input_variables: List[str] + template: str + template_format: str = "f-string" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs.""" + return _FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + input_variables = values["input_variables"] + template = values["template"] + template_format = values["template_format"] + if template_format not in _FORMATTER_MAPPING: + valid_formats = list(_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. 
Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + try: + formatter_func = _FORMATTER_MAPPING[template_format] + formatter_func(template, **dummy_inputs) + except KeyError: + raise ValueError("Invalid prompt schema.") + return values diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..8eedb8d8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.isort] +profile = "black" + +[tool.mypy] +ignore_missing_imports = "True" +disallow_untyped_defs = "True" +exclude = ["notebooks"] diff --git a/readthedocs.yml b/readthedocs.yml new file mode 100644 index 00000000..832cc62f --- /dev/null +++ b/readthedocs.yml @@ -0,0 +1,10 @@ +version: 2 +sphinx: + configuration: docs/conf.py +formats: all +python: + version: 3.6 + install: + - requirements: docs/requirements.txt + - method: pip + path: . diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..4f1a9711 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +-r test_requirements.txt +black +isort +mypy +flake8 +flake8-docstrings +cohere +openai +google-search-results diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..b84bd609 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +"""Set up the package.""" +from pathlib import Path + +from setuptools import find_packages, setup + +with open(Path(__file__).absolute().parents[0] / "langchain" / "VERSION") as _f: + __version__ = _f.read().strip() + +with open("README.md", "r") as f: + long_description = f.read() + +setup( + name="langchain", + version=__version__, + packages=find_packages(), + description="Building applications with LLMs through composability", + install_requires=["pydantic"], + long_description=long_description, + license="MIT", + url="https://github.com/hwchase17/langchain", + include_package_data=True, + long_description_content_type="text/markdown", +) diff --git a/test_requirements.txt b/test_requirements.txt new file mode 100644 index 00000000..aea9aec7 --- /dev/null +++ b/test_requirements.txt @@ -0,0 +1,3 @@ +-e . +pytest +pytest-dotenv diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..4c210e33 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""All tests for this package.""" diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 00000000..a076291f --- /dev/null +++ b/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""All integration tests (tests that call out to an external API).""" diff --git a/tests/integration_tests/chains/__init__.py b/tests/integration_tests/chains/__init__.py new file mode 100644 index 00000000..3ca24201 --- /dev/null +++ b/tests/integration_tests/chains/__init__.py @@ -0,0 +1 @@ +"""All integration tests for chains.""" diff --git a/tests/integration_tests/chains/test_self_ask_with_search.py b/tests/integration_tests/chains/test_self_ask_with_search.py new file mode 100644 index 00000000..7ec49a8b --- /dev/null +++ b/tests/integration_tests/chains/test_self_ask_with_search.py @@ -0,0 +1,18 @@ +"""Integration test for self ask with search.""" +from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain +from langchain.chains.serpapi import SerpAPIChain +from langchain.llms.openai import OpenAI + + +def test_self_ask_with_search() -> None: + """Test functionality on a prompt.""" + question = "What is the hometown of the reigning men's U.S. Open champion?" 
+ chain = SelfAskWithSearchChain( + llm=OpenAI(temperature=0), + search_chain=SerpAPIChain(), + input_key="q", + output_key="a", + ) + answer = chain.run(question) + final_answer = answer.split("\n")[-1] + assert final_answer == "So the final answer is: El Palmar, Murcia, Spain" diff --git a/tests/integration_tests/chains/test_serpapi.py b/tests/integration_tests/chains/test_serpapi.py new file mode 100644 index 00000000..5f4aa048 --- /dev/null +++ b/tests/integration_tests/chains/test_serpapi.py @@ -0,0 +1,9 @@ +"""Integration test for SerpAPI.""" +from langchain.chains.serpapi import SerpAPIChain + + +def test_call() -> None: + """Test that call gives the correct answer.""" + chain = SerpAPIChain() + output = chain.search("What was Obama's first name?") + assert output == "Barack Hussein Obama II" diff --git a/tests/integration_tests/llms/__init__.py b/tests/integration_tests/llms/__init__.py new file mode 100644 index 00000000..6ad06b85 --- /dev/null +++ b/tests/integration_tests/llms/__init__.py @@ -0,0 +1 @@ +"""All integration tests for LLM objects.""" diff --git a/tests/integration_tests/llms/test_cohere.py b/tests/integration_tests/llms/test_cohere.py new file mode 100644 index 00000000..f1a8a6c3 --- /dev/null +++ b/tests/integration_tests/llms/test_cohere.py @@ -0,0 +1,10 @@ +"""Test Cohere API wrapper.""" + +from langchain.llms.cohere import Cohere + + +def test_cohere_call() -> None: + """Test valid call to cohere.""" + llm = Cohere(max_tokens=10) + output = llm("Say foo:") + assert isinstance(output, str) diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py new file mode 100644 index 00000000..850e91f4 --- /dev/null +++ b/tests/integration_tests/llms/test_openai.py @@ -0,0 +1,10 @@ +"""Test OpenAI API wrapper.""" + +from langchain.llms.openai import OpenAI + + +def test_openai_call() -> None: + """Test valid call to openai.""" + llm = OpenAI(max_tokens=10) + output = llm("Say foo:") + assert isinstance(output, str) diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 00000000..307b5085 --- /dev/null +++ b/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""All unit tests (lightweight tests).""" diff --git a/tests/unit_tests/chains/__init__.py b/tests/unit_tests/chains/__init__.py new file mode 100644 index 00000000..e1765c67 --- /dev/null +++ b/tests/unit_tests/chains/__init__.py @@ -0,0 +1 @@ +"""Tests for correct functioning of chains.""" diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py new file mode 100644 index 00000000..95ebb6f7 --- /dev/null +++ b/tests/unit_tests/chains/test_base.py @@ -0,0 +1,50 @@ +"""Test logic on base chain class.""" +from typing import Dict, List + +import pytest +from pydantic import BaseModel + +from langchain.chains.base import Chain + + +class FakeChain(Chain, BaseModel): + """Fake chain class for testing purposes.""" + + be_correct: bool = True + + @property + def input_keys(self) -> List[str]: + """Input key of foo.""" + return ["foo"] + + @property + def output_keys(self) -> List[str]: + """Output key of bar.""" + return ["bar"] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + if self.be_correct: + return {"bar": "baz"} + else: + return {"baz": "bar"} + + +def test_bad_inputs() -> None: + """Test errors are raised if input keys are not found.""" + chain = FakeChain() + with pytest.raises(ValueError): + chain({"foobar": "baz"}) + + +def test_bad_outputs() -> None: + """Test errors are raised if 
output keys are not found.""" + chain = FakeChain(be_correct=False) + with pytest.raises(ValueError): + chain({"foo": "baz"}) + + +def test_correct_call() -> None: + """Test correct call of fake chain.""" + chain = FakeChain() + output = chain({"foo": "bar"}) + assert output == {"foo": "bar", "bar": "baz"} diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py new file mode 100644 index 00000000..4c350637 --- /dev/null +++ b/tests/unit_tests/chains/test_llm.py @@ -0,0 +1,36 @@ +"""Test LLM chain.""" +import pytest + +from langchain.chains.llm import LLMChain +from langchain.prompt import Prompt +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@pytest.fixture +def fake_llm_chain() -> LLMChain: + """Fake LLM chain for testing purposes.""" + prompt = Prompt(input_variables=["bar"], template="This is a {bar}:") + return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") + + +def test_missing_inputs(fake_llm_chain: LLMChain) -> None: + """Test error is raised if inputs are missing.""" + with pytest.raises(ValueError): + fake_llm_chain({"foo": "bar"}) + + +def test_valid_call(fake_llm_chain: LLMChain) -> None: + """Test valid call of LLM chain.""" + output = fake_llm_chain({"bar": "baz"}) + assert output == {"bar": "baz", "text1": "foo"} + + # Test with stop words. + output = fake_llm_chain({"bar": "baz", "stop": ["foo"]}) + # Response should be `bar` now. + assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"} + + +def test_predict_method(fake_llm_chain: LLMChain) -> None: + """Test predict method works.""" + output = fake_llm_chain.predict(bar="baz") + assert output == "foo" diff --git a/tests/unit_tests/chains/test_llm_math.py b/tests/unit_tests/chains/test_llm_math.py new file mode 100644 index 00000000..b38d89dd --- /dev/null +++ b/tests/unit_tests/chains/test_llm_math.py @@ -0,0 +1,40 @@ +"""Test LLM Math functionality.""" + +import pytest + +from langchain.chains.llm_math.base import LLMMathChain +from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@pytest.fixture +def fake_llm_math_chain() -> LLMMathChain: + """Fake LLM Math chain for testing.""" + complex_question = _PROMPT_TEMPLATE.format(question="What is the square root of 2?") + queries = { + _PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2", + complex_question: "```python\nprint(2**.5)\n```", + _PROMPT_TEMPLATE.format(question="foo"): "foo", + } + fake_llm = FakeLLM(queries=queries) + return LLMMathChain(llm=fake_llm, input_key="q", output_key="a") + + +def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None: + """Test simple question that should not need python.""" + question = "What is 1 plus 1?" + output = fake_llm_math_chain.run(question) + assert output == "Answer: 2" + + +def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None: + """Test complex question that should need python.""" + question = "What is the square root of 2?" 
+ output = fake_llm_math_chain.run(question) + assert output == f"Answer: {2**.5}\n" + + +def test_error(fake_llm_math_chain: LLMMathChain) -> None: + """Test question that raises error.""" + with pytest.raises(ValueError): + fake_llm_math_chain.run("foo") diff --git a/tests/unit_tests/chains/test_python.py b/tests/unit_tests/chains/test_python.py new file mode 100644 index 00000000..1677a76a --- /dev/null +++ b/tests/unit_tests/chains/test_python.py @@ -0,0 +1,15 @@ +"""Test python chain.""" + +from langchain.chains.python import PythonChain + + +def test_functionality() -> None: + """Test correct functionality.""" + chain = PythonChain(input_key="code1", output_key="output1") + code = "print(1 + 1)" + output = chain({"code1": code}) + assert output == {"code1": code, "output1": "2\n"} + + # Test with the more user-friendly interface. + simple_output = chain.run(code) + assert simple_output == "2\n" diff --git a/tests/unit_tests/data/prompts/prompt_extra_args.json b/tests/unit_tests/data/prompts/prompt_extra_args.json new file mode 100644 index 00000000..4bfc4fdc --- /dev/null +++ b/tests/unit_tests/data/prompts/prompt_extra_args.json @@ -0,0 +1,5 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test.", + "bad_var": 1 +} \ No newline at end of file diff --git a/tests/unit_tests/data/prompts/prompt_missing_args.json b/tests/unit_tests/data/prompts/prompt_missing_args.json new file mode 100644 index 00000000..cb69d843 --- /dev/null +++ b/tests/unit_tests/data/prompts/prompt_missing_args.json @@ -0,0 +1,3 @@ +{ + "input_variables": ["foo"] +} \ No newline at end of file diff --git a/tests/unit_tests/data/prompts/simple_prompt.json b/tests/unit_tests/data/prompts/simple_prompt.json new file mode 100644 index 00000000..d0f72b1c --- /dev/null +++ b/tests/unit_tests/data/prompts/simple_prompt.json @@ -0,0 +1,4 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test." 
+} \ No newline at end of file diff --git a/tests/unit_tests/llms/__init__.py b/tests/unit_tests/llms/__init__.py new file mode 100644 index 00000000..95bd682b --- /dev/null +++ b/tests/unit_tests/llms/__init__.py @@ -0,0 +1 @@ +"""All unit tests for LLM objects.""" diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py new file mode 100644 index 00000000..60e3d871 --- /dev/null +++ b/tests/unit_tests/llms/fake_llm.py @@ -0,0 +1,21 @@ +"""Fake LLM wrapper for testing purposes.""" +from typing import List, Mapping, Optional + +from langchain.llms.base import LLM + + +class FakeLLM(LLM): + """Fake LLM wrapper for testing purposes.""" + + def __init__(self, queries: Optional[Mapping] = None): + """Initialize with optional lookup of queries.""" + self._queries = queries + + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Return `foo` if no stop words, otherwise `bar`.""" + if self._queries is not None: + return self._queries[prompt] + if stop is None: + return "foo" + else: + return "bar" diff --git a/tests/unit_tests/llms/test_cohere.py b/tests/unit_tests/llms/test_cohere.py new file mode 100644 index 00000000..9e30c333 --- /dev/null +++ b/tests/unit_tests/llms/test_cohere.py @@ -0,0 +1,17 @@ +"""Test helper functions for Cohere API.""" + +from langchain.llms.cohere import remove_stop_tokens + + +def test_remove_stop_tokens() -> None: + """Test removing stop tokens when they occur.""" + text = "foo bar baz" + output = remove_stop_tokens(text, ["moo", "baz"]) + assert output == "foo bar " + + +def test_remove_stop_tokens_none() -> None: + """Test removing stop tokens when they do not occur.""" + text = "foo bar baz" + output = remove_stop_tokens(text, ["moo"]) + assert output == "foo bar baz" diff --git a/tests/unit_tests/test_formatting.py b/tests/unit_tests/test_formatting.py new file mode 100644 index 00000000..168e580b --- /dev/null +++ b/tests/unit_tests/test_formatting.py @@ -0,0 +1,26 @@ +"""Test formatting functionality.""" +import pytest + +from langchain.formatting import formatter + + +def test_valid_formatting() -> None: + """Test formatting works as expected.""" + template = "This is a {foo} test." + output = formatter.format(template, foo="good") + expected_output = "This is a good test." + assert output == expected_output + + +def test_does_not_allow_args() -> None: + """Test formatting raises error when args are provided.""" + template = "This is a {} test." + with pytest.raises(ValueError): + formatter.format(template, "good") + + +def test_does_not_allow_extra_kwargs() -> None: + """Test formatting does not allow extra key word arguments.""" + template = "This is a {foo} test." + with pytest.raises(KeyError): + formatter.format(template, foo="good", bar="oops") diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py new file mode 100644 index 00000000..07624869 --- /dev/null +++ b/tests/unit_tests/test_prompt.py @@ -0,0 +1,47 @@ +"""Test functionality related to prompts.""" +import pytest + +from langchain.prompt import Prompt + + +def test_prompt_valid() -> None: + """Test prompts can be constructed.""" + template = "This is a {foo} test." + input_variables = ["foo"] + prompt = Prompt(input_variables=input_variables, template=template) + assert prompt.template == template + assert prompt.input_variables == input_variables + + +def test_prompt_missing_input_variables() -> None: + """Test error is raised when input variables are not provided.""" + template = "This is a {foo} test." 
+ input_variables: list = [] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) + + +def test_prompt_extra_input_variables() -> None: + """Test error is raised when there are too many input variables.""" + template = "This is a {foo} test." + input_variables = ["foo", "bar"] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) + + +def test_prompt_wrong_input_variables() -> None: + """Test error is raised when name of input variable is wrong.""" + template = "This is a {foo} test." + input_variables = ["bar"] + with pytest.raises(ValueError): + Prompt(input_variables=input_variables, template=template) + + +def test_prompt_invalid_template_format() -> None: + """Test initializing a prompt with invalid template format.""" + template = "This is a {foo} test." + input_variables = ["foo"] + with pytest.raises(ValueError): + Prompt( + input_variables=input_variables, template=template, template_format="bar" + )
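
Note for anyone trying this commit out: the unit tests above are the closest thing to usage documentation at this stage. The snippet below is a minimal sketch of the Prompt and LLMChain flow those tests exercise; it assumes the package has been installed in editable mode (pip install -e ., as test_requirements.txt does) and is run from the repository root so the FakeLLM test helper is importable. Swap in OpenAI or Cohere from langchain.llms (with the corresponding API key set) for a real completion.

    # Minimal sketch mirroring tests/unit_tests/chains/test_llm.py.
    # FakeLLM returns "foo" whenever no stop tokens are supplied.
    from langchain.chains.llm import LLMChain
    from langchain.prompt import Prompt
    from tests.unit_tests.llms.fake_llm import FakeLLM

    prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
    chain = LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")

    print(chain.predict(bar="baz"))  # -> "foo"
    print(chain({"bar": "baz"}))     # -> {"bar": "baz", "text1": "foo"}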