forked from Archives/langchain
initial commit
commit
18aeb72012
@ -0,0 +1,11 @@
|
|||||||
|
[flake8]
|
||||||
|
exclude =
|
||||||
|
.venv
|
||||||
|
__pycache__
|
||||||
|
notebooks
|
||||||
|
# Recommend matching the black line length (default 88),
|
||||||
|
# rather than using the flake8 default of 79:
|
||||||
|
max-line-length = 88
|
||||||
|
extend-ignore =
|
||||||
|
# See https://github.com/PyCQA/pycodestyle/issues/373
|
||||||
|
E203,
|
@ -0,0 +1,23 @@
|
|||||||
|
name: lint
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.7"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Analysing the code with our lint
|
||||||
|
run: |
|
||||||
|
make lint
|
@ -0,0 +1,23 @@
|
|||||||
|
name: test
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.7"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r test_requirements.txt
|
||||||
|
- name: Run unit tests
|
||||||
|
run: |
|
||||||
|
make tests
|
@ -0,0 +1,130 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
notebooks/
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
@ -0,0 +1,2 @@
|
|||||||
|
include langchain/VERSION
|
||||||
|
include LICENSE
|
@ -0,0 +1,17 @@
|
|||||||
|
.PHONY: format lint tests integration_tests
|
||||||
|
|
||||||
|
format:
|
||||||
|
black .
|
||||||
|
isort .
|
||||||
|
|
||||||
|
lint:
|
||||||
|
mypy .
|
||||||
|
black . --check
|
||||||
|
isort . --check
|
||||||
|
flake8 .
|
||||||
|
|
||||||
|
tests:
|
||||||
|
pytest tests/unit_tests
|
||||||
|
|
||||||
|
integration_tests:
|
||||||
|
pytest tests/integration_tests
|
@ -0,0 +1,80 @@
|
|||||||
|
# 🦜️🔗 LangChain
|
||||||
|
|
||||||
|
⚡ Building applications with LLMs through composability ⚡
|
||||||
|
|
||||||
|
[![lint](https://github.com/hwchase17/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [![test](https://github.com/hwchase17/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Quick Install
|
||||||
|
|
||||||
|
`pip install langchain`
|
||||||
|
|
||||||
|
## 🤔 What is this?
|
||||||
|
|
||||||
|
Large language models (LLMs) are emerging as a transformative technology, enabling
|
||||||
|
developers to build applications that they previously could not.
|
||||||
|
But using these LLMs in isolation is often not enough to
|
||||||
|
create a truly powerful app - the real power comes when you are able to
|
||||||
|
combine them with other sources of computation or knowledge.
|
||||||
|
|
||||||
|
This library is aimed at assisting in the development of those types of applications.
|
||||||
|
It aims to create:
|
||||||
|
1. a comprehensive collection of pieces you would ever want to combine
|
||||||
|
2. a flexible interface for combining pieces into a single comprehensive "chain"
|
||||||
|
3. a schema for easily saving and sharing those chains
|
||||||
|
|
||||||
|
## 🚀 What can I do with this
|
||||||
|
|
||||||
|
This project was largely inspired by a few projects seen on Twitter for which we thought it would make sense to have more explicit tooling. A lot of the initial functionality was done in an attempt to recreate those. Those are:
|
||||||
|
|
||||||
|
**[Self-ask-with-search](https://ofir.io/self-ask.pdf)**
|
||||||
|
|
||||||
|
To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb).
|
||||||
|
|
||||||
|
```
|
||||||
|
from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
|
||||||
|
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
search = SerpAPIChain()
|
||||||
|
|
||||||
|
self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search)
|
||||||
|
|
||||||
|
self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open champion?")
|
||||||
|
```
|
||||||
|
|
||||||
|
**[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)**
|
||||||
|
|
||||||
|
To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb).
|
||||||
|
|
||||||
|
```
|
||||||
|
from langchain import OpenAI, LLMMathChain
|
||||||
|
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
llm_math = LLMMathChain(llm=llm)
|
||||||
|
|
||||||
|
llm_math.run("How many of the integers between 0 and 99 inclusive are divisible by 8?")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Generic Prompting**
|
||||||
|
|
||||||
|
You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb).
|
||||||
|
|
||||||
|
```
|
||||||
|
from langchain import Prompt, OpenAI, LLMChain
|
||||||
|
|
||||||
|
template = """Question: {question}
|
||||||
|
|
||||||
|
Answer: Let's think step by step."""
|
||||||
|
prompt = Prompt(template=template, input_variables=["question"])
|
||||||
|
llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))
|
||||||
|
|
||||||
|
question = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
|
||||||
|
|
||||||
|
llm_chain.predict(question=question)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📖 Documentation
|
||||||
|
|
||||||
|
The above examples are probably the most user-friendly documentation that exists,
|
||||||
|
but full API docs can be found [here](https://langchain.readthedocs.io/en/latest/?).
|
@ -0,0 +1,21 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
#
|
||||||
|
|
||||||
|
# You can set these variables from the command line, and also
|
||||||
|
# from the environment for the first two.
|
||||||
|
SPHINXOPTS ?=
|
||||||
|
SPHINXBUILD ?= sphinx-build
|
||||||
|
SPHINXAUTOBUILD ?= sphinx-autobuild
|
||||||
|
SOURCEDIR = .
|
||||||
|
BUILDDIR = _build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
@ -0,0 +1,65 @@
|
|||||||
|
"""Configuration file for the Sphinx documentation builder."""
|
||||||
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
#
|
||||||
|
# This file only contains a selection of the most common options. For a full
|
||||||
|
# list see the documentation:
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||||
|
|
||||||
|
# -- Path setup --------------------------------------------------------------
|
||||||
|
|
||||||
|
# If extensions (or modules to document with autodoc) are in another directory,
|
||||||
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
|
#
|
||||||
|
# import os
|
||||||
|
# import sys
|
||||||
|
# sys.path.insert(0, os.path.abspath('.'))
|
||||||
|
|
||||||
|
import langchain
|
||||||
|
|
||||||
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
project = "LangChain"
|
||||||
|
copyright = "2022, Harrison Chase"
|
||||||
|
author = "Harrison Chase"
|
||||||
|
|
||||||
|
version = langchain.__version__
|
||||||
|
release = langchain.__version__
|
||||||
|
|
||||||
|
|
||||||
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
|
# ones.
|
||||||
|
extensions = [
|
||||||
|
"sphinx.ext.autodoc",
|
||||||
|
"sphinx.ext.autodoc.typehints",
|
||||||
|
"sphinx.ext.autosummary",
|
||||||
|
"sphinx.ext.napoleon",
|
||||||
|
]
|
||||||
|
|
||||||
|
# autodoc_typehints = "signature"
|
||||||
|
autodoc_typehints = "description"
|
||||||
|
|
||||||
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
|
templates_path = ["_templates"]
|
||||||
|
|
||||||
|
# List of patterns, relative to source directory, that match files and
|
||||||
|
# directories to ignore when looking for source files.
|
||||||
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
|
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||||
|
|
||||||
|
|
||||||
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
|
# a list of builtin themes.
|
||||||
|
#
|
||||||
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
# html_theme = "sphinx_typlog_theme"
|
||||||
|
|
||||||
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
|
html_static_path = ["_static"]
|
@ -0,0 +1,10 @@
|
|||||||
|
Welcome to LangChain
|
||||||
|
==========================
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: User API
|
||||||
|
|
||||||
|
modules/prompt
|
||||||
|
modules/llms
|
||||||
|
modules/chains
|
@ -0,0 +1,35 @@
|
|||||||
|
@ECHO OFF
|
||||||
|
|
||||||
|
pushd %~dp0
|
||||||
|
|
||||||
|
REM Command file for Sphinx documentation
|
||||||
|
|
||||||
|
if "%SPHINXBUILD%" == "" (
|
||||||
|
set SPHINXBUILD=sphinx-build
|
||||||
|
)
|
||||||
|
set SOURCEDIR=.
|
||||||
|
set BUILDDIR=_build
|
||||||
|
|
||||||
|
if "%1" == "" goto help
|
||||||
|
|
||||||
|
%SPHINXBUILD% >NUL 2>NUL
|
||||||
|
if errorlevel 9009 (
|
||||||
|
echo.
|
||||||
|
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||||
|
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||||
|
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||||
|
echo.may add the Sphinx directory to PATH.
|
||||||
|
echo.
|
||||||
|
echo.If you don't have Sphinx installed, grab it from
|
||||||
|
echo.http://sphinx-doc.org/
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
goto end
|
||||||
|
|
||||||
|
:help
|
||||||
|
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
|
||||||
|
:end
|
||||||
|
popd
|
@ -0,0 +1,7 @@
|
|||||||
|
:mod:`langchain.chains`
|
||||||
|
=======================
|
||||||
|
|
||||||
|
.. automodule:: langchain.chains
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
|
@ -0,0 +1,6 @@
|
|||||||
|
:mod:`langchain.llms`
|
||||||
|
=======================
|
||||||
|
|
||||||
|
.. automodule:: langchain.llms
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
@ -0,0 +1,6 @@
|
|||||||
|
:mod:`langchain.prompt`
|
||||||
|
=======================
|
||||||
|
|
||||||
|
.. automodule:: langchain.prompt
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
@ -0,0 +1,4 @@
|
|||||||
|
sphinx==4.5.0
|
||||||
|
sphinx-autobuild==2021.3.14
|
||||||
|
sphinx_rtd_theme==1.0.0
|
||||||
|
sphinx-typlog-theme==0.8.0
|
@ -0,0 +1,59 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "44e9ba31",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Answer: 13\\n'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain import OpenAI, LLMMathChain\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI(temperature=0)\n",
|
||||||
|
"llm_math = LLMMathChain(llm=llm)\n",
|
||||||
|
"\n",
|
||||||
|
"llm_math.run(\"How many of the integers between 0 and 99 inclusive are divisible by 8?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f62f0c75",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,74 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "7e3b513e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"What is the hometown of the reigning men's U.S. Open champion?\n",
|
||||||
|
"Are follow up questions needed here:\u001b[102m Yes.\n",
|
||||||
|
"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
|
||||||
|
"Intermediate answer: \u001b[106mCarlos Alcaraz\u001b[0m.\u001b[102m\n",
|
||||||
|
"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
|
||||||
|
"Intermediate answer: \u001b[106mEl Palmar, Murcia, Spain\u001b[0m.\u001b[102m\n",
|
||||||
|
"So the final answer is: El Palmar, Murcia, Spain\u001b[0m"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"\"What is the hometown of the reigning men's U.S. Open champion?\\nAre follow up questions needed here: Yes.\\nFollow up: Who is the reigning men's U.S. Open champion?\\nIntermediate answer: Carlos Alcaraz.\\nFollow up: Where is Carlos Alcaraz from?\\nIntermediate answer: El Palmar, Murcia, Spain.\\nSo the final answer is: El Palmar, Murcia, Spain\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI(temperature=0)\n",
|
||||||
|
"search = SerpAPIChain()\n",
|
||||||
|
"\n",
|
||||||
|
"self_ask_with_search = SelfAskWithSearchChain(llm=llm, search_chain=search)\n",
|
||||||
|
"\n",
|
||||||
|
"self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6195fc82",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "51a54c4d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"' The year Justin Beiber was born was 1994. In 1994, the Dallas Cowboys won the Super Bowl.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain import Prompt, OpenAI, LLMChain\n",
|
||||||
|
"\n",
|
||||||
|
"template = \"\"\"Question: {question}\n",
|
||||||
|
"\n",
|
||||||
|
"Answer: Let's think step by step.\"\"\"\n",
|
||||||
|
"prompt = Prompt(template=template, input_variables=[\"question\"])\n",
|
||||||
|
"llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))\n",
|
||||||
|
"\n",
|
||||||
|
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
|
||||||
|
"\n",
|
||||||
|
"llm_chain.predict(question=question)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "03dd6918",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
0.0.1
|
@ -0,0 +1,27 @@
|
|||||||
|
"""Main entrypoint into package."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
|
||||||
|
__version__ = _f.read().strip()
|
||||||
|
|
||||||
|
from langchain.chains import (
|
||||||
|
LLMChain,
|
||||||
|
LLMMathChain,
|
||||||
|
PythonChain,
|
||||||
|
SelfAskWithSearchChain,
|
||||||
|
SerpAPIChain,
|
||||||
|
)
|
||||||
|
from langchain.llms import Cohere, OpenAI
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMChain",
|
||||||
|
"LLMMathChain",
|
||||||
|
"PythonChain",
|
||||||
|
"SelfAskWithSearchChain",
|
||||||
|
"SerpAPIChain",
|
||||||
|
"Cohere",
|
||||||
|
"OpenAI",
|
||||||
|
"Prompt",
|
||||||
|
]
|
@ -0,0 +1,14 @@
|
|||||||
|
"""Chains are easily reusable components which can be linked together."""
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
|
from langchain.chains.python import PythonChain
|
||||||
|
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMChain",
|
||||||
|
"LLMMathChain",
|
||||||
|
"PythonChain",
|
||||||
|
"SelfAskWithSearchChain",
|
||||||
|
"SerpAPIChain",
|
||||||
|
]
|
@ -0,0 +1,41 @@
|
|||||||
|
"""Base interface that all chains should implement."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
class Chain(ABC):
    """Base interface that all chains should implement.

    A chain maps a dict of named inputs to a dict of named outputs.
    Subclasses declare their expected keys via ``input_keys`` /
    ``output_keys`` and implement the core logic in ``_run``.
    """

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, str]) -> None:
        """Raise ValueError unless every expected input key is present."""
        missing_keys = set(self.input_keys).difference(inputs)
        if not missing_keys:
            return
        raise ValueError(f"Missing some input keys: {missing_keys}")

    def _validate_outputs(self, outputs: Dict[str, str]) -> None:
        """Raise ValueError unless ``_run`` produced exactly the declared keys."""
        produced, expected = set(outputs), set(self.output_keys)
        if produced == expected:
            return
        raise ValueError(
            f"Did not get output keys that were expected. "
            f"Got: {produced}. Expected: {expected}."
        )

    @abstractmethod
    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the logic of this chain and return the output."""

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Validate inputs, run the chain, and merge its outputs into them."""
        self._validate_inputs(inputs)
        outputs = self._run(inputs)
        self._validate_outputs(outputs)
        return {**inputs, **outputs}
|
@ -0,0 +1,46 @@
|
|||||||
|
"""Chain that just formats a prompt and calls an LLM."""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
class LLMChain(Chain, BaseModel):
    """Chain to run queries against LLMs.

    Formats ``prompt`` with the supplied inputs, feeds the result to
    ``llm``, and returns the raw completion under ``output_key``.
    """

    # Prompt template the chain fills in before calling the LLM.
    prompt: Prompt
    # LLM wrapper invoked with the formatted prompt.
    llm: LLM
    # Name under which the LLM response is returned.
    output_key: str = "text"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Format the prompt from ``inputs`` and call the LLM once."""
        prompt_text = self.prompt.format(
            **{key: inputs[key] for key in self.prompt.input_variables}
        )
        # An optional "stop" entry in the inputs is forwarded to the LLM.
        call_kwargs = {"stop": inputs["stop"]} if "stop" in inputs else {}
        return {self.output_key: self.llm(prompt_text, **call_kwargs)}

    def predict(self, **kwargs: Any) -> str:
        """More user-friendly interface for interacting with LLMs."""
        return self(kwargs)[self.output_key]
|
@ -0,0 +1,4 @@
|
|||||||
|
"""Chain that interprets a prompt and executes python code to do math.
|
||||||
|
|
||||||
|
Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
|
||||||
|
"""
|
@ -0,0 +1,57 @@
|
|||||||
|
"""Chain that interprets a prompt and executes python code to do math."""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.llm_math.prompt import PROMPT
|
||||||
|
from langchain.chains.python import PythonChain
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class LLMMathChain(Chain, BaseModel):
    """Chain that interprets a prompt and executes python code to do math."""

    # LLM wrapper used to translate the question into python code.
    llm: LLM
    # When True, prints the generated code before evaluating it.
    verbose: bool = False
    # Key under which the question is supplied.
    input_key: str = "question"
    # Key under which the answer is returned.
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
        python_executor = PythonChain()
        question = inputs[self.input_key]
        # Stop generation at "```output" so the LLM emits only a code block
        # (or a direct "Answer:" line), never a fabricated result.
        t = llm_executor.predict(question=question, stop=["```output"]).strip()
        if t.startswith("```python"):
            # Strip the leading "```python" fence (9 chars) and 4 trailing
            # chars; assumes the completion ends with "\n```" — TODO confirm.
            code = t[9:-4]
            if self.verbose:
                print("[DEBUG] evaluating code")
                print(code)
            # NOTE(review): the generated code is exec'd by PythonChain,
            # which trusts the LLM output; do not expose to untrusted input.
            output = python_executor.run(code)
            answer = "Answer: " + output
        elif t.startswith("Answer:"):
            # The model answered directly without emitting code.
            answer = t
        else:
            raise ValueError(f"unknown format from LLM: {t}")
        return {self.output_key: answer}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with LLM math."""
        return self({self.input_key: question})[self.output_key]
|
@ -0,0 +1,40 @@
|
|||||||
|
"""Chain that runs python code.
|
||||||
|
|
||||||
|
Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from io import StringIO
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
|
class PythonChain(Chain, BaseModel):
    """Chain to run python code.

    The code is executed with ``exec`` and everything it prints to
    stdout is captured and returned under ``output_key``.

    SECURITY NOTE: ``exec`` runs arbitrary code with full interpreter
    privileges — never feed this chain untrusted input.
    """

    # Key under which the code to execute is supplied.
    input_key: str = "code"
    # Key under which the captured stdout is returned.
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Execute the code and capture everything it writes to stdout."""
        old_stdout = sys.stdout
        sys.stdout = mystdout = StringIO()
        try:
            exec(inputs[self.input_key])
        finally:
            # Always restore stdout, even if the executed code raises;
            # otherwise every later print in the process is swallowed
            # by the abandoned StringIO buffer.
            sys.stdout = old_stdout
        return {self.output_key: mystdout.getvalue()}

    def run(self, code: str) -> str:
        """More user-friendly interface for interfacing with python."""
        return self({self.input_key: code})[self.output_key]
|
@ -0,0 +1,4 @@
|
|||||||
|
"""Chain that does self ask with search.
|
||||||
|
|
||||||
|
Heavily borrowed from https://github.com/ofirpress/self-ask
|
||||||
|
"""
|
@ -0,0 +1,142 @@
|
|||||||
|
"""Chain that does self ask with search."""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.self_ask_with_search.prompt import PROMPT
|
||||||
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def extract_answer(generated: str) -> str:
    """Extract answer from text.

    Takes the last line of *generated*, then the text after the final
    colon, stripping one leading space and one trailing period.

    Args:
        generated: Raw completion text from the LLM.

    Returns:
        The extracted answer, or ``""`` when nothing usable is present.
    """
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if ":" not in last_line:
        after_colon = last_line
    else:
        # NOTE: intentionally splits the full text (not just the last
        # line) on ":" to preserve the original extraction behavior.
        after_colon = generated.split(":")[-1]

    # Guard against empty text: the original [0]/[-1] indexing raised
    # IndexError on empty input or when only a colon-plus-space remains.
    if not after_colon:
        return after_colon
    if " " == after_colon[0]:
        after_colon = after_colon[1:]
    if after_colon and "." == after_colon[-1]:
        after_colon = after_colon[:-1]

    return after_colon
|
||||||
|
|
||||||
|
|
||||||
|
def extract_question(generated: str, followup: str) -> str:
    """Extract the follow-up question from the last line of ``generated``.

    Warns (via print) when the text does not look like a follow-up line,
    then returns the text after the last colon with a single leading space
    removed.
    """
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if followup not in last_line:
        print("we probably should never get here..." + generated)

    if ":" not in last_line:
        after_colon = last_line
    else:
        # NOTE(review): splits the full text rather than just the last
        # line -- kept as-is to match the reference implementation.
        after_colon = generated.split(":")[-1]

    # Guard against an empty tail; after_colon[0] / after_colon[-1]
    # previously raised IndexError when the text ended with ":".
    if after_colon.startswith(" "):
        after_colon = after_colon[1:]
    if after_colon and "?" != after_colon[-1]:
        print("we probably should never get here..." + generated)

    return after_colon
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_line(generated: str) -> str:
    """Return the final line of ``generated`` (the whole text if single-line)."""
    # rsplit with maxsplit=1 yields [generated] when there is no newline,
    # so the [-1] covers both the single-line and multi-line cases.
    return generated.rsplit("\n", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def greenify(_input: str) -> str:
    """Wrap *_input* in ANSI escape codes for a bright-green background."""
    return "".join(("\x1b[102m", _input, "\x1b[0m"))
|
||||||
|
|
||||||
|
|
||||||
|
def yellowfy(_input: str) -> str:
    """Wrap *_input* in ANSI escape codes for a highlighted background.

    NOTE(review): escape code 106 is a bright-cyan background, although the
    name suggests yellow (103) -- kept as-is to preserve console output.
    """
    return "".join(("\x1b[106m", _input, "\x1b[0m"))
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAskWithSearchChain(Chain, BaseModel):
    """Chain that does self ask with search."""

    # LLM used to propose follow-up questions and compose answers.
    llm: LLM
    # Chain used to answer intermediate follow-up questions via SerpAPI.
    search_chain: SerpAPIChain
    # Key under which the user's question arrives.
    input_key: str = "question"
    # Key under which the final transcript/answer is returned.
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the self-ask loop, alternating LLM follow-ups with searches.

        The running prompt transcript grows in ``cur_prompt``; colored
        console output mirrors the transcript as it is built (green for
        LLM text, yellow/cyan for search answers).
        """
        question = inputs[self.input_key]
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # Marker strings the few-shot prompt trains the LLM to emit/stop on.
        intermediate = "\nIntermediate answer:"
        followup = "Follow up:"
        finalans = "\nSo the final answer is:"
        cur_prompt = f"{question}\nAre follow up questions needed here:"
        print(cur_prompt, end="")
        ret_text = llm_chain.predict(input=cur_prompt, stop=[intermediate])
        print(greenify(ret_text), end="")
        # Keep looping while the model keeps asking follow-up questions.
        while followup in get_last_line(ret_text):
            cur_prompt += ret_text
            question = extract_question(ret_text, followup)
            external_answer = self.search_chain.search(question)
            if external_answer is not None:
                # Feed the search result back as the intermediate answer.
                cur_prompt += intermediate + " " + external_answer + "."
                print(
                    intermediate + " " + yellowfy(external_answer) + ".",
                    end="",
                )
                ret_text = llm_chain.predict(
                    input=cur_prompt, stop=["\nIntermediate answer:"]
                )
                print(greenify(ret_text), end="")
            else:
                # We only get here in the very rare case that Google returns no answer.
                # Let the LLM produce the intermediate answer itself.
                cur_prompt += intermediate
                print(intermediate + " ")
                cur_prompt += llm_chain.predict(
                    input=cur_prompt, stop=["\n" + followup, finalans]
                )

        # Force a final answer if the model has not produced one yet.
        if finalans not in ret_text:
            cur_prompt += finalans
            print(finalans, end="")
            ret_text = llm_chain.predict(input=cur_prompt, stop=["\n"])
            print(greenify(ret_text), end="")

        # NOTE(review): returns the full transcript plus the final
        # completion, not just the answer text -- callers parse it.
        return {self.output_key: cur_prompt + ret_text}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with self ask with search."""
        return self({self.input_key: question})[self.output_key]
|
@ -0,0 +1,44 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing?
|
||||||
|
Are follow up questions needed here: Yes.
|
||||||
|
Follow up: How old was Muhammad Ali when he died?
|
||||||
|
Intermediate answer: Muhammad Ali was 74 years old when he died.
|
||||||
|
Follow up: How old was Alan Turing when he died?
|
||||||
|
Intermediate answer: Alan Turing was 41 years old when he died.
|
||||||
|
So the final answer is: Muhammad Ali
|
||||||
|
|
||||||
|
Question: When was the founder of craigslist born?
|
||||||
|
Are follow up questions needed here: Yes.
|
||||||
|
Follow up: Who was the founder of craigslist?
|
||||||
|
Intermediate answer: Craigslist was founded by Craig Newmark.
|
||||||
|
Follow up: When was Craig Newmark born?
|
||||||
|
Intermediate answer: Craig Newmark was born on December 6, 1952.
|
||||||
|
So the final answer is: December 6, 1952
|
||||||
|
|
||||||
|
Question: Who was the maternal grandfather of George Washington?
|
||||||
|
Are follow up questions needed here: Yes.
|
||||||
|
Follow up: Who was the mother of George Washington?
|
||||||
|
Intermediate answer: The mother of George Washington was Mary Ball Washington.
|
||||||
|
Follow up: Who was the father of Mary Ball Washington?
|
||||||
|
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
|
||||||
|
So the final answer is: Joseph Ball
|
||||||
|
|
||||||
|
Question: Are both the directors of Jaws and Casino Royale from the same country?
|
||||||
|
Are follow up questions needed here: Yes.
|
||||||
|
Follow up: Who is the director of Jaws?
|
||||||
|
Intermediate Answer: The director of Jaws is Steven Spielberg.
|
||||||
|
Follow up: Where is Steven Spielberg from?
|
||||||
|
Intermediate Answer: The United States.
|
||||||
|
Follow up: Who is the director of Casino Royale?
|
||||||
|
Intermediate Answer: The director of Casino Royale is Martin Campbell.
|
||||||
|
Follow up: Where is Martin Campbell from?
|
||||||
|
Intermediate Answer: New Zealand.
|
||||||
|
So the final answer is: No
|
||||||
|
|
||||||
|
Question: {input}"""
|
||||||
|
# Few-shot self-ask prompt: the worked examples in the template demonstrate
# the exact "Follow up:" / "Intermediate answer:" / "So the final answer is:"
# markers that SelfAskWithSearchChain parses out of completions.
PROMPT = Prompt(
    input_variables=["input"],
    template=_DEFAULT_TEMPLATE,
)
|
@ -0,0 +1,99 @@
|
|||||||
|
"""Chain that calls SerpAPI.
|
||||||
|
|
||||||
|
Heavily borrowed from https://github.com/ofirpress/self-ask
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
|
class HiddenPrints:
    """Context manager that silences ``print`` by redirecting stdout to devnull."""

    def __enter__(self) -> None:
        """Swap ``sys.stdout`` for a handle on the null device."""
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, *_: Any) -> None:
        """Close the null-device handle and restore the original stdout."""
        devnull_handle = sys.stdout
        devnull_handle.close()
        sys.stdout = self._original_stdout
|
||||||
|
|
||||||
|
|
||||||
|
class SerpAPIChain(Chain, BaseModel):
    """Chain that calls SerpAPI.

    Requires the ``google-search-results`` package and the
    ``SERPAPI_API_KEY`` environment variable (checked by the validator).
    """

    # GoogleSearch class injected by validate_environment below.
    search_engine: Any
    input_key: str = "search_query"
    output_key: str = "search_result"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key."""
        return [self.output_key]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        if "SERPAPI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find SerpAPI API key, please add an environment variable"
                " `SERPAPI_API_KEY` which contains it."
            )
        try:
            from serpapi import GoogleSearch

            values["search_engine"] = GoogleSearch
        except ImportError:
            # Fixed garbled message (was "Please it install it with").
            raise ValueError(
                "Could not import serpapi python package. "
                "Please install it with `pip install google-search-results`."
            )
        return values

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Query SerpAPI and pull the most direct answer from the response.

        Preference order: answer-box answer > answer-box snippet >
        highlighted words > first organic-result snippet; the output value
        is None when none of these are present.
        """
        params = {
            "api_key": os.environ["SERPAPI_API_KEY"],
            "engine": "google",
            "q": inputs[self.input_key],
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en",
        }
        with HiddenPrints():
            search = self.search_engine(params)
            res = search.get_dict()

        answer_box = res.get("answer_box", {})
        organic_results = res.get("organic_results", [])
        if "answer" in answer_box:
            toret = answer_box["answer"]
        elif "snippet" in answer_box:
            toret = answer_box["snippet"]
        elif "snippet_highlighted_words" in answer_box:
            toret = answer_box["snippet_highlighted_words"][0]
        elif organic_results and "snippet" in organic_results[0]:
            # Previously raised KeyError/IndexError when SerpAPI returned
            # no organic results at all; now falls through to None.
            toret = organic_results[0]["snippet"]
        else:
            toret = None
        return {self.output_key: toret}

    def search(self, search_question: str) -> str:
        """More user-friendly interface for interfacing with search.

        NOTE(review): can actually return None when no answer is found,
        despite the ``str`` annotation -- callers check for None.
        """
        return self({self.input_key: search_question})[self.output_key]
|
@ -0,0 +1,32 @@
|
|||||||
|
"""Utilities for formatting strings."""
|
||||||
|
from string import Formatter
|
||||||
|
from typing import Any, Mapping, Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
class StrictFormatter(Formatter):
    """A ``string.Formatter`` that rejects positional args and unused keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Raise ``KeyError`` if any keyword argument went unused."""
        unused = set(kwargs) - set(used_args)
        if unused:
            raise KeyError(unused)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Format the string, refusing any positional arguments."""
        if args:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)


# Shared module-level instance used by the prompt machinery.
formatter = StrictFormatter()
|
@ -0,0 +1,5 @@
|
|||||||
|
"""Wrappers on top of large language models."""
|
||||||
|
from langchain.llms.cohere import Cohere
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
|
||||||
|
__all__ = ["Cohere", "OpenAI"]
|
@ -0,0 +1,11 @@
|
|||||||
|
"""Base interface for large language models to expose."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class LLM(ABC):
    """Abstract interface for LLM wrappers: a callable from prompt to completion."""

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt``, optionally halting at a stop sequence."""
|
@ -0,0 +1,72 @@
|
|||||||
|
"""Wrapper around Cohere APIs."""
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def remove_stop_tokens(text: str, stop: List[str]) -> str:
    """Strip the first stop sequence found at the end of ``text``, if any."""
    matching = (token for token in stop if text.endswith(token))
    token = next(matching, None)
    if token is None:
        return text
    return text[: -len(token)]
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere(BaseModel, LLM):
    """Wrapper around Cohere large language models.

    Requires the ``cohere`` package and the ``COHERE_API_KEY`` environment
    variable (checked by the validator below).
    """

    # cohere.Client instance injected by the validator.
    client: Any
    # Generation parameters passed straight through to the Cohere API.
    model: str = "gptd-instruct-tft"
    max_tokens: int = 256
    temperature: float = 0.6
    k: int = 0
    p: int = 1
    frequency_penalty: int = 0
    presence_penalty: int = 0

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        if "COHERE_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find Cohere API key, please add an environment variable"
                " `COHERE_API_KEY` which contains it."
            )
        try:
            import cohere

            values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
        except ImportError:
            # Fixed garbled message (was "Please it install it with").
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )
        return values

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Cohere's generate endpoint and return the first generation."""
        response = self.client.generate(
            model=self.model,
            prompt=prompt,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            k=self.k,
            p=self.p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop_sequences=stop,
        )
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None:
            text = remove_stop_tokens(text, stop)
        return text
|
@ -0,0 +1,65 @@
|
|||||||
|
"""Wrapper around OpenAI APIs."""
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI(BaseModel, LLM):
    """Wrapper around OpenAI large language models.

    Requires the ``openai`` package and the ``OPENAI_API_KEY`` environment
    variable (checked by the validator below).
    """

    # openai.Completion class injected by the validator.
    client: Any
    # Completion parameters passed straight through to the OpenAI API.
    model_name: str = "text-davinci-002"
    temperature: float = 0.7
    max_tokens: int = 256
    top_p: int = 1
    frequency_penalty: int = 0
    presence_penalty: int = 0
    n: int = 1
    best_of: int = 1

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            # Fixed garbled message (was "Please it install it with").
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @property
    def default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "best_of": self.best_of,
        }

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to OpenAI's create endpoint and return the first choice's text."""
        response = self.client.create(
            model=self.model_name, prompt=prompt, stop=stop, **self.default_params
        )
        return response["choices"][0]["text"]
|
@ -0,0 +1,47 @@
|
|||||||
|
"""Prompt schema definition."""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.formatting import formatter
|
||||||
|
|
||||||
|
_FORMATTER_MAPPING = {
|
||||||
|
"f-string": formatter.format,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(BaseModel):
    """Schema to represent a prompt for an LLM."""

    # Names of the variables the template expects.
    input_variables: List[str]
    # The template text itself.
    template: str
    # Formatting language used; must be a key of _FORMATTER_MAPPING.
    template_format: str = "f-string"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs."""
        return _FORMATTER_MAPPING[self.template_format](self.template, **kwargs)

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that template and input variables are consistent."""
        input_variables = values["input_variables"]
        template = values["template"]
        template_format = values["template_format"]
        if template_format not in _FORMATTER_MAPPING:
            valid_formats = list(_FORMATTER_MAPPING)
            raise ValueError(
                f"Invalid template format. Got `{template_format}`;"
                f" should be one of {valid_formats}"
            )
        # Dry-run the formatter with dummy values: a missing variable
        # raises KeyError from format(); an extra one is rejected by
        # StrictFormatter.check_unused_args (also KeyError).
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        try:
            formatter_func = _FORMATTER_MAPPING[template_format]
            formatter_func(template, **dummy_inputs)
        except KeyError:
            raise ValueError("Invalid prompt schema.")
        return values
|
@ -0,0 +1,7 @@
|
|||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
ignore_missing_imports = "True"
|
||||||
|
disallow_untyped_defs = "True"
|
||||||
|
exclude = ["notebooks"]
|
@ -0,0 +1,10 @@
|
|||||||
|
version: 2
|
||||||
|
sphinx:
|
||||||
|
configuration: docs/conf.py
|
||||||
|
formats: all
|
||||||
|
python:
|
||||||
|
version: 3.6
|
||||||
|
install:
|
||||||
|
- requirements: docs/requirements.txt
|
||||||
|
- method: pip
|
||||||
|
path: .
|
@ -0,0 +1,9 @@
|
|||||||
|
-r test_requirements.txt
|
||||||
|
black
|
||||||
|
isort
|
||||||
|
mypy
|
||||||
|
flake8
|
||||||
|
flake8-docstrings
|
||||||
|
cohere
|
||||||
|
openai
|
||||||
|
google-search-results
|
@ -0,0 +1,23 @@
|
|||||||
|
"""Set up the package."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
with open(Path(__file__).absolute().parents[0] / "langchain" / "VERSION") as _f:
|
||||||
|
__version__ = _f.read().strip()
|
||||||
|
|
||||||
|
with open("README.md", "r") as f:
|
||||||
|
long_description = f.read()
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="langchain",
|
||||||
|
version=__version__,
|
||||||
|
packages=find_packages(),
|
||||||
|
description="Building applications with LLMs through composability",
|
||||||
|
install_requires=["pydantic"],
|
||||||
|
long_description=long_description,
|
||||||
|
license="MIT",
|
||||||
|
url="https://github.com/hwchase17/langchain",
|
||||||
|
include_package_data=True,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
)
|
@ -0,0 +1,3 @@
|
|||||||
|
-e .
|
||||||
|
pytest
|
||||||
|
pytest-dotenv
|
@ -0,0 +1 @@
|
|||||||
|
"""All tests for this package."""
|
@ -0,0 +1 @@
|
|||||||
|
"""All integration tests (tests that call out to an external API)."""
|
@ -0,0 +1 @@
|
|||||||
|
"""All integration tests for chains."""
|
@ -0,0 +1,18 @@
|
|||||||
|
"""Integration test for self ask with search."""
|
||||||
|
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def test_self_ask_with_search() -> None:
|
||||||
|
"""Test functionality on a prompt."""
|
||||||
|
question = "What is the hometown of the reigning men's U.S. Open champion?"
|
||||||
|
chain = SelfAskWithSearchChain(
|
||||||
|
llm=OpenAI(temperature=0),
|
||||||
|
search_chain=SerpAPIChain(),
|
||||||
|
input_key="q",
|
||||||
|
output_key="a",
|
||||||
|
)
|
||||||
|
answer = chain.run(question)
|
||||||
|
final_answer = answer.split("\n")[-1]
|
||||||
|
assert final_answer == "So the final answer is: El Palmar, Murcia, Spain"
|
@ -0,0 +1,9 @@
|
|||||||
|
"""Integration test for SerpAPI."""
|
||||||
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
|
||||||
|
|
||||||
|
def test_call() -> None:
|
||||||
|
"""Test that call gives the correct answer."""
|
||||||
|
chain = SerpAPIChain()
|
||||||
|
output = chain.search("What was Obama's first name?")
|
||||||
|
assert output == "Barack Hussein Obama II"
|
@ -0,0 +1 @@
|
|||||||
|
"""All integration tests for LLM objects."""
|
@ -0,0 +1,10 @@
|
|||||||
|
"""Test Cohere API wrapper."""
|
||||||
|
|
||||||
|
from langchain.llms.cohere import Cohere
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call() -> None:
|
||||||
|
"""Test valid call to cohere."""
|
||||||
|
llm = Cohere(max_tokens=10)
|
||||||
|
output = llm("Say foo:")
|
||||||
|
assert isinstance(output, str)
|
@ -0,0 +1,10 @@
|
|||||||
|
"""Test OpenAI API wrapper."""
|
||||||
|
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_call() -> None:
    """Test valid call to openai."""
    # Renamed from test_cohere_call: the name and docstring were
    # copy-pasted from the Cohere test module.
    llm = OpenAI(max_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
|
@ -0,0 +1 @@
|
|||||||
|
"""All unit tests (lightweight tests)."""
|
@ -0,0 +1 @@
|
|||||||
|
"""Tests for correct functioning of chains."""
|
@ -0,0 +1,50 @@
|
|||||||
|
"""Test logic on base chain class."""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChain(Chain, BaseModel):
|
||||||
|
"""Fake chain class for testing purposes."""
|
||||||
|
|
||||||
|
be_correct: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input key of foo."""
|
||||||
|
return ["foo"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Output key of bar."""
|
||||||
|
return ["bar"]
|
||||||
|
|
||||||
|
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
if self.be_correct:
|
||||||
|
return {"bar": "baz"}
|
||||||
|
else:
|
||||||
|
return {"baz": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_inputs() -> None:
|
||||||
|
"""Test errors are raised if input keys are not found."""
|
||||||
|
chain = FakeChain()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain({"foobar": "baz"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_outputs() -> None:
|
||||||
|
"""Test errors are raised if outputs keys are not found."""
|
||||||
|
chain = FakeChain(be_correct=False)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain({"foo": "baz"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_correct_call() -> None:
|
||||||
|
"""Test correct call of fake chain."""
|
||||||
|
chain = FakeChain()
|
||||||
|
output = chain({"foo": "bar"})
|
||||||
|
assert output == {"foo": "bar", "bar": "baz"}
|
@ -0,0 +1,36 @@
|
|||||||
|
"""Test LLM chain."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_llm_chain() -> LLMChain:
|
||||||
|
"""Fake LLM chain for testing purposes."""
|
||||||
|
prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
|
||||||
|
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
|
||||||
|
"""Test error is raised if inputs are missing."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
fake_llm_chain({"foo": "bar"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_call(fake_llm_chain: LLMChain) -> None:
|
||||||
|
"""Test valid call of LLM chain."""
|
||||||
|
output = fake_llm_chain({"bar": "baz"})
|
||||||
|
assert output == {"bar": "baz", "text1": "foo"}
|
||||||
|
|
||||||
|
# Test with stop words.
|
||||||
|
output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
|
||||||
|
# Response should be `bar` now.
|
||||||
|
assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_method(fake_llm_chain: LLMChain) -> None:
|
||||||
|
"""Test predict method works."""
|
||||||
|
output = fake_llm_chain.predict(bar="baz")
|
||||||
|
assert output == "foo"
|
@ -0,0 +1,40 @@
|
|||||||
|
"""Test LLM Math functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
|
from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_llm_math_chain() -> LLMMathChain:
|
||||||
|
"""Fake LLM Math chain for testing."""
|
||||||
|
complex_question = _PROMPT_TEMPLATE.format(question="What is the square root of 2?")
|
||||||
|
queries = {
|
||||||
|
_PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
|
||||||
|
complex_question: "```python\nprint(2**.5)\n```",
|
||||||
|
_PROMPT_TEMPLATE.format(question="foo"): "foo",
|
||||||
|
}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
|
||||||
|
"""Test simple question that should not need python."""
|
||||||
|
question = "What is 1 plus 1?"
|
||||||
|
output = fake_llm_math_chain.run(question)
|
||||||
|
assert output == "Answer: 2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
|
||||||
|
"""Test complex question that should need python."""
|
||||||
|
question = "What is the square root of 2?"
|
||||||
|
output = fake_llm_math_chain.run(question)
|
||||||
|
assert output == f"Answer: {2**.5}\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_error(fake_llm_math_chain: LLMMathChain) -> None:
|
||||||
|
"""Test question that raises error."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
fake_llm_math_chain.run("foo")
|
@ -0,0 +1,15 @@
|
|||||||
|
"""Test python chain."""
|
||||||
|
|
||||||
|
from langchain.chains.python import PythonChain
|
||||||
|
|
||||||
|
|
||||||
|
def test_functionality() -> None:
|
||||||
|
"""Test correct functionality."""
|
||||||
|
chain = PythonChain(input_key="code1", output_key="output1")
|
||||||
|
code = "print(1 + 1)"
|
||||||
|
output = chain({"code1": code})
|
||||||
|
assert output == {"code1": code, "output1": "2\n"}
|
||||||
|
|
||||||
|
# Test with the more user-friendly interface.
|
||||||
|
simple_output = chain.run(code)
|
||||||
|
assert simple_output == "2\n"
|
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"input_variables": ["foo"],
|
||||||
|
"template": "This is a {foo} test.",
|
||||||
|
"bad_var": 1
|
||||||
|
}
|
@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"input_variables": ["foo"]
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"input_variables": ["foo"],
|
||||||
|
"template": "This is a {foo} test."
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
"""All unit tests for LLM objects."""
|
@ -0,0 +1,21 @@
|
|||||||
|
"""Fake LLM wrapper for testing purposes."""
|
||||||
|
from typing import List, Mapping, Optional
|
||||||
|
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    def __init__(self, queries: Optional[Mapping] = None):
        """Optionally seed a prompt -> response lookup table."""
        self._queries = queries

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Return the canned answer for ``prompt``; else "foo"/"bar" by stop words."""
        if self._queries is not None:
            return self._queries[prompt]
        return "foo" if stop is None else "bar"
|
@ -0,0 +1,17 @@
|
|||||||
|
"""Test helper functions for Cohere API."""
|
||||||
|
|
||||||
|
from langchain.llms.cohere import remove_stop_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_stop_tokens() -> None:
|
||||||
|
"""Test removing stop tokens when they occur."""
|
||||||
|
text = "foo bar baz"
|
||||||
|
output = remove_stop_tokens(text, ["moo", "baz"])
|
||||||
|
assert output == "foo bar "
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_stop_tokens_none() -> None:
|
||||||
|
"""Test removing stop tokens when they do not occur."""
|
||||||
|
text = "foo bar baz"
|
||||||
|
output = remove_stop_tokens(text, ["moo"])
|
||||||
|
assert output == "foo bar baz"
|
@ -0,0 +1,26 @@
|
|||||||
|
"""Test formatting functionality."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.formatting import formatter
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_formatting() -> None:
|
||||||
|
"""Test formatting works as expected."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
output = formatter.format(template, foo="good")
|
||||||
|
expected_output = "This is a good test."
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_allow_args() -> None:
|
||||||
|
"""Test formatting raises error when args are provided."""
|
||||||
|
template = "This is a {} test."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
formatter.format(template, "good")
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_allow_extra_kwargs() -> None:
|
||||||
|
"""Test formatting does not allow extra key word arguments."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
formatter.format(template, foo="good", bar="oops")
|
@ -0,0 +1,47 @@
|
|||||||
|
"""Test functionality related to prompts."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_valid() -> None:
|
||||||
|
"""Test prompts can be constructed."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
input_variables = ["foo"]
|
||||||
|
prompt = Prompt(input_variables=input_variables, template=template)
|
||||||
|
assert prompt.template == template
|
||||||
|
assert prompt.input_variables == input_variables
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_missing_input_variables() -> None:
|
||||||
|
"""Test error is raised when input variables are not provided."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
input_variables: list = []
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Prompt(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_extra_input_variables() -> None:
|
||||||
|
"""Test error is raised when there are too many input variables."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
input_variables = ["foo", "bar"]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Prompt(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_wrong_input_variables() -> None:
|
||||||
|
"""Test error is raised when name of input variable is wrong."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
input_variables = ["bar"]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Prompt(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_invalid_template_format() -> None:
|
||||||
|
"""Test initializing a prompt with invalid template format."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
input_variables = ["foo"]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Prompt(
|
||||||
|
input_variables=input_variables, template=template, template_format="bar"
|
||||||
|
)
|
Loading…
Reference in New Issue