forked from Archives/langchain
commit 18aeb72012: initial commit
@ -0,0 +1,11 @@
[flake8]
exclude =
    .venv
    __pycache__
    notebooks
# Recommend matching the black line length (default 88),
# rather than using the flake8 default of 79:
max-line-length = 88
extend-ignore =
    # See https://github.com/PyCQA/pycodestyle/issues/373
    E203,
@ -0,0 +1,23 @@
name: lint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.7"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
      - name: Analyze the code with our linters
        run: |
          make lint
@ -0,0 +1,23 @@
name: test

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.7"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r test_requirements.txt
      - name: Run unit tests
        run: |
          make tests
@ -0,0 +1,130 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
@ -0,0 +1,2 @@
include langchain/VERSION
include LICENSE
@ -0,0 +1,17 @@
.PHONY: format lint tests integration_tests

format:
	black .
	isort .

lint:
	mypy .
	black . --check
	isort . --check
	flake8 .

tests:
	pytest tests/unit_tests

integration_tests:
	pytest tests/integration_tests
@ -0,0 +1,80 @@
# 🦜️🔗 LangChain

⚡ Building applications with LLMs through composability ⚡

[![lint](https://github.com/hwchase17/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [![test](https://github.com/hwchase17/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

## Quick Install

`pip install langchain`

## 🤔 What is this?

Large language models (LLMs) are emerging as a transformative technology, enabling
developers to build applications that they previously could not.
But using these LLMs in isolation is often not enough to
create a truly powerful app - the real power comes when you can
combine them with other sources of computation or knowledge.

This library is aimed at assisting in the development of those types of applications.
It aims to create:
1. a comprehensive collection of the pieces you would ever want to combine
2. a flexible interface for combining pieces into a single comprehensive "chain"
3. a schema for easily saving and sharing those chains

## 🚀 What can I do with this

This project was largely inspired by a few projects seen on Twitter for which we thought it would make sense to have more explicit tooling. Much of the initial functionality was built in an attempt to recreate them:

**[Self-ask-with-search](https://ofir.io/self-ask.pdf)**

To recreate this paper, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb).

```python
from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain

llm = OpenAI(temperature=0)
search = SerpAPIChain()

self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search)

self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open champion?")
```

**[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)**

To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb).

```python
from langchain import OpenAI, LLMMathChain

llm = OpenAI(temperature=0)
llm_math = LLMMathChain(llm=llm)

llm_math.run("How many of the integers between 0 and 99 inclusive are divisible by 8?")
```

**Generic Prompting**

You can also use this for simple prompting pipelines, as in the example below and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb).

```python
from langchain import Prompt, OpenAI, LLMChain

template = """Question: {question}

Answer: Let's think step by step."""
prompt = Prompt(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))

question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"

llm_chain.predict(question=question)
```

## 📖 Documentation

The above examples are probably the most user-friendly documentation that exists,
but full API docs can be found [here](https://langchain.readthedocs.io/en/latest/?).
@ -0,0 +1,21 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS      ?=
SPHINXBUILD     ?= sphinx-build
SPHINXAUTOBUILD ?= sphinx-autobuild
SOURCEDIR       = .
BUILDDIR        = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ -0,0 +1,65 @@
"""Configuration file for the Sphinx documentation builder."""
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import langchain

# -- Project information -----------------------------------------------------

project = "LangChain"
copyright = "2022, Harrison Chase"
author = "Harrison Chase"

version = langchain.__version__
release = langchain.__version__


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autodoc.typehints",
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
]

# autodoc_typehints = "signature"
autodoc_typehints = "description"

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# html_theme = "sphinx_typlog_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
@ -0,0 +1,10 @@
Welcome to LangChain
==========================

.. toctree::
   :maxdepth: 2
   :caption: User API

   modules/prompt
   modules/llms
   modules/chains
@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
@ -0,0 +1,7 @@
:mod:`langchain.chains`
=======================

.. automodule:: langchain.chains
   :members:
   :undoc-members:
@ -0,0 +1,6 @@
:mod:`langchain.llms`
=======================

.. automodule:: langchain.llms
   :members:
   :undoc-members:
@ -0,0 +1,6 @@
:mod:`langchain.prompt`
=======================

.. automodule:: langchain.prompt
   :members:
   :undoc-members:
@ -0,0 +1,4 @@
sphinx==4.5.0
sphinx-autobuild==2021.3.14
sphinx_rtd_theme==1.0.0
sphinx-typlog-theme==0.8.0
@ -0,0 +1,59 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "44e9ba31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Answer: 13\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain import OpenAI, LLMMathChain\n",
    "\n",
    "llm = OpenAI(temperature=0)\n",
    "llm_math = LLMMathChain(llm=llm)\n",
    "\n",
    "llm_math.run(\"How many of the integers between 0 and 99 inclusive are divisible by 8?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f62f0c75",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,74 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7e3b513e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the hometown of the reigning men's U.S. Open champion?\n",
      "Are follow up questions needed here:\u001b[102m Yes.\n",
      "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
      "Intermediate answer: \u001b[106mCarlos Alcaraz\u001b[0m.\u001b[102m\n",
      "Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
      "Intermediate answer: \u001b[106mEl Palmar, Murcia, Spain\u001b[0m.\u001b[102m\n",
      "So the final answer is: El Palmar, Murcia, Spain\u001b[0m"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"What is the hometown of the reigning men's U.S. Open champion?\\nAre follow up questions needed here: Yes.\\nFollow up: Who is the reigning men's U.S. Open champion?\\nIntermediate answer: Carlos Alcaraz.\\nFollow up: Where is Carlos Alcaraz from?\\nIntermediate answer: El Palmar, Murcia, Spain.\\nSo the final answer is: El Palmar, Murcia, Spain\""
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain\n",
    "\n",
    "llm = OpenAI(temperature=0)\n",
    "search = SerpAPIChain()\n",
    "\n",
    "self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search)\n",
    "\n",
    "self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6195fc82",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,64 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "51a54c4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' The year Justin Beiber was born was 1994. In 1994, the Dallas Cowboys won the Super Bowl.'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain import Prompt, OpenAI, LLMChain\n",
    "\n",
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = Prompt(template=template, input_variables=[\"question\"])\n",
    "llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))\n",
    "\n",
    "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
    "\n",
    "llm_chain.predict(question=question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03dd6918",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1 @@
0.0.1
@ -0,0 +1,27 @@
"""Main entrypoint into package."""

from pathlib import Path

with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
    __version__ = _f.read().strip()

from langchain.chains import (
    LLMChain,
    LLMMathChain,
    PythonChain,
    SelfAskWithSearchChain,
    SerpAPIChain,
)
from langchain.llms import Cohere, OpenAI
from langchain.prompt import Prompt

__all__ = [
    "LLMChain",
    "LLMMathChain",
    "PythonChain",
    "SelfAskWithSearchChain",
    "SerpAPIChain",
    "Cohere",
    "OpenAI",
    "Prompt",
]
@ -0,0 +1,14 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.python import PythonChain
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain

__all__ = [
    "LLMChain",
    "LLMMathChain",
    "PythonChain",
    "SelfAskWithSearchChain",
    "SerpAPIChain",
]
@ -0,0 +1,41 @@
"""Base interface that all chains should implement."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class Chain(ABC):
    """Base interface that all chains should implement."""

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, str]) -> None:
        """Check that all inputs are present."""
        missing_keys = set(self.input_keys).difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")

    def _validate_outputs(self, outputs: Dict[str, str]) -> None:
        """Check that the produced outputs match the declared output keys."""
        if set(outputs) != set(self.output_keys):
            raise ValueError(
                f"Did not get output keys that were expected. "
                f"Got: {set(outputs)}. Expected: {set(self.output_keys)}."
            )

    @abstractmethod
    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the logic of this chain and return the output."""

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the logic of this chain and add to output."""
        self._validate_inputs(inputs)
        outputs = self._run(inputs)
        self._validate_outputs(outputs)
        return {**inputs, **outputs}
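As a reading aid (not part of the commit): implementing this interface takes only the two key properties and `_run`; the unit tests later in this commit define a similar `FakeChain`. A minimal sketch with a hypothetical `EchoChain`:

```python
from typing import Dict, List

from langchain.chains.base import Chain


class EchoChain(Chain):
    """Hypothetical chain: copies its input key to its output key."""

    @property
    def input_keys(self) -> List[str]:
        """Expect a single `text` key."""
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        """Return a single `echo` key."""
        return ["echo"]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        return {"echo": inputs["text"]}


# __call__ validates keys, runs the chain, and merges inputs with outputs:
assert EchoChain()({"text": "hi"}) == {"text": "hi", "echo": "hi"}
```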
@ -0,0 +1,46 @@
"""Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.llms.base import LLM
from langchain.prompt import Prompt


class LLMChain(Chain, BaseModel):
    """Chain to run queries against LLMs."""

    prompt: Prompt
    llm: LLM
    output_key: str = "text"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format(**selected_inputs)

        kwargs = {}
        if "stop" in inputs:
            kwargs["stop"] = inputs["stop"]
        response = self.llm(prompt, **kwargs)
        return {self.output_key: response}

    def predict(self, **kwargs: Any) -> str:
        """More user-friendly interface for interacting with LLMs."""
        return self(kwargs)[self.output_key]
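One subtlety worth noting (editor's sketch, not part of the commit): a `stop` key in the inputs is never formatted into the prompt; it is forwarded to the LLM call, which the unit tests later rely on. A hypothetical usage, assuming an `OPENAI_API_KEY` is set:

```python
from langchain import LLMChain, OpenAI, Prompt

# Hypothetical prompt; any single-variable template behaves the same way.
prompt = Prompt(input_variables=["product"], template="Name a company that makes {product}.")
chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))

chain.predict(product="socks")               # returns just the completion text
chain({"product": "socks", "stop": ["\n"]})  # "stop" is passed through to the LLM
```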
@ -0,0 +1,4 @@
"""Chain that interprets a prompt and executes python code to do math.

Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
"""
@ -0,0 +1,57 @@
"""Chain that interprets a prompt and executes python code to do math."""
from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.chains.python import PythonChain
from langchain.llms.base import LLM


class LLMMathChain(Chain, BaseModel):
    """Chain that interprets a prompt and executes python code to do math."""

    llm: LLM
    verbose: bool = False
    input_key: str = "question"
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
        python_executor = PythonChain()
        question = inputs[self.input_key]
        t = llm_executor.predict(question=question, stop=["```output"]).strip()
        if t.startswith("```python"):
            # Strip the leading ```python fence and the trailing ``` fence,
            # leaving only the code to execute.
            code = t[9:-4]
            if self.verbose:
                print("[DEBUG] evaluating code")
                print(code)
            output = python_executor.run(code)
            answer = "Answer: " + output
        elif t.startswith("Answer:"):
            # The model answered directly, without emitting code.
            answer = t
        else:
            raise ValueError(f"unknown format from LLM: {t}")
        return {self.output_key: answer}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with LLM math."""
        return self({self.input_key: question})[self.output_key]
@ -0,0 +1,40 @@
"""Chain that runs python code.

Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
"""
import sys
from io import StringIO
from typing import Dict, List

from pydantic import BaseModel

from langchain.chains.base import Chain


class PythonChain(Chain, BaseModel):
    """Chain to run python code."""

    input_key: str = "code"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Temporarily redirect stdout so the executed code's output
        # can be captured and returned.
        old_stdout = sys.stdout
        sys.stdout = mystdout = StringIO()
        try:
            exec(inputs[self.input_key])
        finally:
            # Restore stdout even if the executed code raises.
            sys.stdout = old_stdout
        output = mystdout.getvalue()
        return {self.output_key: output}

    def run(self, code: str) -> str:
        """More user-friendly interface for interfacing with python."""
        return self({self.input_key: code})[self.output_key]
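Usage is one line (this mirrors the unit test later in this commit); because the chain captures stdout, the result includes the trailing newline from `print`:

```python
from langchain.chains.python import PythonChain

chain = PythonChain()
output = chain.run("print(1 + 1)")
assert output == "2\n"  # captured stdout, trailing newline included
```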
@ -0,0 +1,4 @@
"""Chain that does self ask with search.

Heavily borrowed from https://github.com/ofirpress/self-ask
"""
@ -0,0 +1,142 @@
"""Chain that does self ask with search."""
from typing import Any, Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.self_ask_with_search.prompt import PROMPT
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.base import LLM


def extract_answer(generated: str) -> str:
    """Extract answer from text."""
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if ":" not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(":")[-1]

    if " " == after_colon[0]:
        after_colon = after_colon[1:]
    if "." == after_colon[-1]:
        after_colon = after_colon[:-1]

    return after_colon


def extract_question(generated: str, followup: str) -> str:
    """Extract question from text."""
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if followup not in last_line:
        print("we probably should never get here..." + generated)

    if ":" not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(":")[-1]

    if " " == after_colon[0]:
        after_colon = after_colon[1:]
    if "?" != after_colon[-1]:
        print("we probably should never get here..." + generated)

    return after_colon


def get_last_line(generated: str) -> str:
    """Get the last line in text."""
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    return last_line


def greenify(_input: str) -> str:
    """Add green highlighting to text."""
    return "\x1b[102m" + _input + "\x1b[0m"


def yellowfy(_input: str) -> str:
    """Add yellow highlighting to text."""
    return "\x1b[106m" + _input + "\x1b[0m"


class SelfAskWithSearchChain(Chain, BaseModel):
    """Chain that does self ask with search."""

    llm: LLM
    search_chain: SerpAPIChain
    input_key: str = "question"
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        question = inputs[self.input_key]
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        intermediate = "\nIntermediate answer:"
        followup = "Follow up:"
        finalans = "\nSo the final answer is:"
        cur_prompt = f"{question}\nAre follow up questions needed here:"
        print(cur_prompt, end="")
        ret_text = llm_chain.predict(input=cur_prompt, stop=[intermediate])
        print(greenify(ret_text), end="")
        while followup in get_last_line(ret_text):
            cur_prompt += ret_text
            question = extract_question(ret_text, followup)
            external_answer = self.search_chain.search(question)
            if external_answer is not None:
                cur_prompt += intermediate + " " + external_answer + "."
                print(
                    intermediate + " " + yellowfy(external_answer) + ".",
                    end="",
                )
                ret_text = llm_chain.predict(
                    input=cur_prompt, stop=["\nIntermediate answer:"]
                )
                print(greenify(ret_text), end="")
            else:
                # We only get here in the very rare case that Google returns no answer.
                cur_prompt += intermediate
                print(intermediate + " ")
                cur_prompt += llm_chain.predict(
                    input=cur_prompt, stop=["\n" + followup, finalans]
                )

        if finalans not in ret_text:
            cur_prompt += finalans
            print(finalans, end="")
            ret_text = llm_chain.predict(input=cur_prompt, stop=["\n"])
            print(greenify(ret_text), end="")

        return {self.output_key: cur_prompt + ret_text}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with self ask with search."""
        return self({self.input_key: question})[self.output_key]
@ -0,0 +1,44 @@
# flake8: noqa
from langchain.prompt import Prompt

_DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate Answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate Answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate Answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate Answer: New Zealand.
So the final answer is: No

Question: {input}"""
PROMPT = Prompt(
    input_variables=["input"],
    template=_DEFAULT_TEMPLATE,
)
@ -0,0 +1,99 @@
"""Chain that calls SerpAPI.

Heavily borrowed from https://github.com/ofirpress/self-ask
"""
import os
import sys
from typing import Any, Dict, List

from pydantic import BaseModel, Extra, root_validator

from langchain.chains.base import Chain


class HiddenPrints:
    """Context manager to hide prints."""

    def __enter__(self) -> None:
        """Open file to pipe stdout to."""
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, *_: Any) -> None:
        """Close file that stdout was piped to."""
        sys.stdout.close()
        sys.stdout = self._original_stdout


class SerpAPIChain(Chain, BaseModel):
    """Chain that calls SerpAPI."""

    search_engine: Any
    input_key: str = "search_query"
    output_key: str = "search_result"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key."""
        return [self.output_key]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if "SERPAPI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find SerpAPI API key, please add an environment variable"
                " `SERPAPI_API_KEY` which contains it."
            )
        try:
            from serpapi import GoogleSearch

            values["search_engine"] = GoogleSearch
        except ImportError:
            raise ValueError(
                "Could not import serpapi python package. "
                "Please install it with `pip install google-search-results`."
            )
        return values

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        params = {
            "api_key": os.environ["SERPAPI_API_KEY"],
            "engine": "google",
            "q": inputs[self.input_key],
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en",
        }
        with HiddenPrints():
            search = self.search_engine(params)
            res = search.get_dict()

        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif "snippet" in res["organic_results"][0].keys():
            toret = res["organic_results"][0]["snippet"]
        else:
            toret = None
        return {self.output_key: toret}

    def search(self, search_question: str) -> str:
        """More user-friendly interface for interfacing with search."""
        return self({self.input_key: search_question})[self.output_key]
@ -0,0 +1,32 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """A subclass of formatter that checks for extra keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Check to see if extra parameters are passed."""
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Check that no arguments are provided."""
        if len(args) > 0:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)


formatter = StrictFormatter()
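For context (editor's sketch; the unit tests near the end of this commit exercise exactly these cases), the strict formatter accepts keyword-only arguments and rejects anything else:

```python
from langchain.formatting import formatter

assert formatter.format("This is a {foo} test.", foo="good") == "This is a good test."

# Positional arguments are rejected outright:
# formatter.format("This is a {} test.", "good")                     -> ValueError
# Unused keyword arguments are rejected as well:
# formatter.format("This is a {foo} test.", foo="good", bar="oops")  -> KeyError
```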
@ -0,0 +1,5 @@
"""Wrappers on top of large language models."""
from langchain.llms.cohere import Cohere
from langchain.llms.openai import OpenAI

__all__ = ["Cohere", "OpenAI"]
@ -0,0 +1,11 @@
"""Base interface for large language models to expose."""
from abc import ABC, abstractmethod
from typing import List, Optional


class LLM(ABC):
    """LLM wrapper should take in a prompt and return a string."""

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input."""
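A provider wrapper only has to implement `__call__`; the test suite's `FakeLLM` later in this commit is the smallest real example. A hypothetical sketch:

```python
from typing import List, Optional

from langchain.llms.base import LLM


class StaticLLM(LLM):
    """Hypothetical LLM that always returns the same string."""

    def __init__(self, text: str = "hello"):
        self.text = text

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # A real wrapper would call a model API here and honor `stop`.
        return self.text
```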
@ -0,0 +1,72 @@
"""Wrapper around Cohere APIs."""
import os
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


def remove_stop_tokens(text: str, stop: List[str]) -> str:
    """Remove stop tokens, should they occur at end."""
    for s in stop:
        if text.endswith(s):
            return text[: -len(s)]
    return text


class Cohere(BaseModel, LLM):
    """Wrapper around Cohere large language models."""

    client: Any
    model: str = "gptd-instruct-tft"
    max_tokens: int = 256
    temperature: float = 0.6
    k: int = 0
    p: int = 1
    frequency_penalty: int = 0
    presence_penalty: int = 0

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if "COHERE_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find Cohere API key, please add an environment variable"
                " `COHERE_API_KEY` which contains it."
            )
        try:
            import cohere

            values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
        except ImportError:
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )
        return values

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Cohere's generate endpoint."""
        response = self.client.generate(
            model=self.model,
            prompt=prompt,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            k=self.k,
            p=self.p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop_sequences=stop,
        )
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None:
            text = remove_stop_tokens(text, stop)
        return text
@ -0,0 +1,65 @@
"""Wrapper around OpenAI APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


class OpenAI(BaseModel, LLM):
    """Wrapper around OpenAI large language models."""

    client: Any
    model_name: str = "text-davinci-002"
    temperature: float = 0.7
    max_tokens: int = 256
    top_p: int = 1
    frequency_penalty: int = 0
    presence_penalty: int = 0
    n: int = 1
    best_of: int = 1

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @property
    def default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "best_of": self.best_of,
        }

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to OpenAI's create endpoint."""
        response = self.client.create(
            model=self.model_name, prompt=prompt, stop=stop, **self.default_params
        )
        return response["choices"][0]["text"]
@ -0,0 +1,47 @@
"""Prompt schema definition."""
from typing import Any, Dict, List

from pydantic import BaseModel, Extra, root_validator

from langchain.formatting import formatter

_FORMATTER_MAPPING = {
    "f-string": formatter.format,
}


class Prompt(BaseModel):
    """Schema to represent a prompt for an LLM."""

    input_variables: List[str]
    template: str
    template_format: str = "f-string"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs."""
        return _FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that template and input variables are consistent."""
        input_variables = values["input_variables"]
        template = values["template"]
        template_format = values["template_format"]
        if template_format not in _FORMATTER_MAPPING:
            valid_formats = list(_FORMATTER_MAPPING)
            raise ValueError(
                f"Invalid template format. Got `{template_format}`;"
                f" should be one of {valid_formats}"
            )
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        try:
            formatter_func = _FORMATTER_MAPPING[template_format]
            formatter_func(template, **dummy_inputs)
        except KeyError:
            raise ValueError("Invalid prompt schema.")
        return values
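Because of the `template_is_valid` root validator, mismatched variables fail at construction time rather than at format time (the unit tests at the end of the commit cover each case):

```python
from langchain.prompt import Prompt

Prompt(input_variables=["foo"], template="This is a {foo} test.")  # ok

# Wrong, missing, or extra variables all raise ValueError at construction:
# Prompt(input_variables=["bar"], template="This is a {foo} test.")
# Prompt(input_variables=[], template="This is a {foo} test.")
# Prompt(input_variables=["foo", "bar"], template="This is a {foo} test.")
```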
@ -0,0 +1,7 @@
[tool.isort]
profile = "black"

[tool.mypy]
ignore_missing_imports = true
disallow_untyped_defs = true
exclude = ["notebooks"]
@ -0,0 +1,10 @@
version: 2
sphinx:
  configuration: docs/conf.py
formats: all
python:
  version: 3.6
  install:
    - requirements: docs/requirements.txt
    - method: pip
      path: .
@ -0,0 +1,9 @@
-r test_requirements.txt
black
isort
mypy
flake8
flake8-docstrings
cohere
openai
google-search-results
@ -0,0 +1,23 @@
"""Set up the package."""
from pathlib import Path

from setuptools import find_packages, setup

with open(Path(__file__).absolute().parents[0] / "langchain" / "VERSION") as _f:
    __version__ = _f.read().strip()

with open("README.md", "r") as f:
    long_description = f.read()

setup(
    name="langchain",
    version=__version__,
    packages=find_packages(),
    description="Building applications with LLMs through composability",
    install_requires=["pydantic"],
    long_description=long_description,
    license="MIT",
    url="https://github.com/hwchase17/langchain",
    include_package_data=True,
    long_description_content_type="text/markdown",
)
@ -0,0 +1,3 @@
-e .
pytest
pytest-dotenv
@ -0,0 +1 @@
"""All tests for this package."""
@ -0,0 +1 @@
"""All integration tests (tests that call out to an external API)."""
@ -0,0 +1 @@
"""All integration tests for chains."""
@ -0,0 +1,18 @@
"""Integration test for self ask with search."""
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.openai import OpenAI


def test_self_ask_with_search() -> None:
    """Test functionality on a prompt."""
    question = "What is the hometown of the reigning men's U.S. Open champion?"
    chain = SelfAskWithSearchChain(
        llm=OpenAI(temperature=0),
        search_chain=SerpAPIChain(),
        input_key="q",
        output_key="a",
    )
    answer = chain.run(question)
    final_answer = answer.split("\n")[-1]
    assert final_answer == "So the final answer is: El Palmar, Murcia, Spain"
@ -0,0 +1,9 @@
"""Integration test for SerpAPI."""
from langchain.chains.serpapi import SerpAPIChain


def test_call() -> None:
    """Test that call gives the correct answer."""
    chain = SerpAPIChain()
    output = chain.search("What was Obama's first name?")
    assert output == "Barack Hussein Obama II"
@ -0,0 +1 @@
"""All integration tests for LLM objects."""
@ -0,0 +1,10 @@
"""Test Cohere API wrapper."""

from langchain.llms.cohere import Cohere


def test_cohere_call() -> None:
    """Test valid call to cohere."""
    llm = Cohere(max_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
@ -0,0 +1,10 @@
"""Test OpenAI API wrapper."""

from langchain.llms.openai import OpenAI


def test_openai_call() -> None:
    """Test valid call to openai."""
    llm = OpenAI(max_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
@ -0,0 +1 @@
"""All unit tests (lightweight tests)."""
@ -0,0 +1 @@
"""Tests for correct functioning of chains."""
@ -0,0 +1,50 @@
"""Test logic on base chain class."""
from typing import Dict, List

import pytest
from pydantic import BaseModel

from langchain.chains.base import Chain


class FakeChain(Chain, BaseModel):
    """Fake chain class for testing purposes."""

    be_correct: bool = True

    @property
    def input_keys(self) -> List[str]:
        """Input key of foo."""
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        """Output key of bar."""
        return ["bar"]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        if self.be_correct:
            return {"bar": "baz"}
        else:
            return {"baz": "bar"}


def test_bad_inputs() -> None:
    """Test errors are raised if input keys are not found."""
    chain = FakeChain()
    with pytest.raises(ValueError):
        chain({"foobar": "baz"})


def test_bad_outputs() -> None:
    """Test errors are raised if output keys are not found."""
    chain = FakeChain(be_correct=False)
    with pytest.raises(ValueError):
        chain({"foo": "baz"})


def test_correct_call() -> None:
    """Test correct call of fake chain."""
    chain = FakeChain()
    output = chain({"foo": "bar"})
    assert output == {"foo": "bar", "bar": "baz"}
@ -0,0 +1,36 @@
"""Test LLM chain."""
import pytest

from langchain.chains.llm import LLMChain
from langchain.prompt import Prompt
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Fake LLM chain for testing purposes."""
    prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
    return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")


def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Test error is raised if inputs are missing."""
    with pytest.raises(ValueError):
        fake_llm_chain({"foo": "bar"})


def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """Test valid call of LLM chain."""
    output = fake_llm_chain({"bar": "baz"})
    assert output == {"bar": "baz", "text1": "foo"}

    # Test with stop words.
    output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
    # Response should be `bar` now.
    assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}


def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """Test predict method works."""
    output = fake_llm_chain.predict(bar="baz")
    assert output == "foo"
@ -0,0 +1,40 @@
"""Test LLM Math functionality."""

import pytest

from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_math_chain() -> LLMMathChain:
    """Fake LLM Math chain for testing."""
    complex_question = _PROMPT_TEMPLATE.format(question="What is the square root of 2?")
    queries = {
        _PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
        complex_question: "```python\nprint(2**.5)\n```",
        _PROMPT_TEMPLATE.format(question="foo"): "foo",
    }
    fake_llm = FakeLLM(queries=queries)
    return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")


def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test simple question that should not need python."""
    question = "What is 1 plus 1?"
    output = fake_llm_math_chain.run(question)
    assert output == "Answer: 2"


def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test complex question that should need python."""
    question = "What is the square root of 2?"
    output = fake_llm_math_chain.run(question)
    assert output == f"Answer: {2**.5}\n"


def test_error(fake_llm_math_chain: LLMMathChain) -> None:
    """Test question that raises error."""
    with pytest.raises(ValueError):
        fake_llm_math_chain.run("foo")
@ -0,0 +1,15 @@
"""Test python chain."""

from langchain.chains.python import PythonChain


def test_functionality() -> None:
    """Test correct functionality."""
    chain = PythonChain(input_key="code1", output_key="output1")
    code = "print(1 + 1)"
    output = chain({"code1": code})
    assert output == {"code1": code, "output1": "2\n"}

    # Test with the more user-friendly interface.
    simple_output = chain.run(code)
    assert simple_output == "2\n"
@ -0,0 +1,5 @@
{
    "input_variables": ["foo"],
    "template": "This is a {foo} test.",
    "bad_var": 1
}
@ -0,0 +1,3 @@
{
    "input_variables": ["foo"]
}
@ -0,0 +1,4 @@
{
    "input_variables": ["foo"],
    "template": "This is a {foo} test."
}
@ -0,0 +1 @@
"""All unit tests for LLM objects."""
@ -0,0 +1,21 @@
"""Fake LLM wrapper for testing purposes."""
from typing import List, Mapping, Optional

from langchain.llms.base import LLM


class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    def __init__(self, queries: Optional[Mapping] = None):
        """Initialize with optional lookup of queries."""
        self._queries = queries

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Return `foo` if no stop words, otherwise `bar`."""
        if self._queries is not None:
            return self._queries[prompt]
        if stop is None:
            return "foo"
        else:
            return "bar"
@ -0,0 +1,17 @@
"""Test helper functions for Cohere API."""

from langchain.llms.cohere import remove_stop_tokens


def test_remove_stop_tokens() -> None:
    """Test removing stop tokens when they occur."""
    text = "foo bar baz"
    output = remove_stop_tokens(text, ["moo", "baz"])
    assert output == "foo bar "


def test_remove_stop_tokens_none() -> None:
    """Test removing stop tokens when they do not occur."""
    text = "foo bar baz"
    output = remove_stop_tokens(text, ["moo"])
    assert output == "foo bar baz"
@ -0,0 +1,26 @@
"""Test formatting functionality."""
import pytest

from langchain.formatting import formatter


def test_valid_formatting() -> None:
    """Test formatting works as expected."""
    template = "This is a {foo} test."
    output = formatter.format(template, foo="good")
    expected_output = "This is a good test."
    assert output == expected_output


def test_does_not_allow_args() -> None:
    """Test formatting raises error when args are provided."""
    template = "This is a {} test."
    with pytest.raises(ValueError):
        formatter.format(template, "good")


def test_does_not_allow_extra_kwargs() -> None:
    """Test formatting does not allow extra key word arguments."""
    template = "This is a {foo} test."
    with pytest.raises(KeyError):
        formatter.format(template, foo="good", bar="oops")
@ -0,0 +1,47 @@
"""Test functionality related to prompts."""
import pytest

from langchain.prompt import Prompt


def test_prompt_valid() -> None:
    """Test prompts can be constructed."""
    template = "This is a {foo} test."
    input_variables = ["foo"]
    prompt = Prompt(input_variables=input_variables, template=template)
    assert prompt.template == template
    assert prompt.input_variables == input_variables


def test_prompt_missing_input_variables() -> None:
    """Test error is raised when input variables are not provided."""
    template = "This is a {foo} test."
    input_variables: list = []
    with pytest.raises(ValueError):
        Prompt(input_variables=input_variables, template=template)


def test_prompt_extra_input_variables() -> None:
    """Test error is raised when there are too many input variables."""
    template = "This is a {foo} test."
    input_variables = ["foo", "bar"]
    with pytest.raises(ValueError):
        Prompt(input_variables=input_variables, template=template)


def test_prompt_wrong_input_variables() -> None:
    """Test error is raised when name of input variable is wrong."""
    template = "This is a {foo} test."
    input_variables = ["bar"]
    with pytest.raises(ValueError):
        Prompt(input_variables=input_variables, template=template)


def test_prompt_invalid_template_format() -> None:
    """Test initializing a prompt with invalid template format."""
    template = "This is a {foo} test."
    input_variables = ["foo"]
    with pytest.raises(ValueError):
        Prompt(
            input_variables=input_variables, template=template, template_format="bar"
        )