mirror of https://github.com/hwchase17/langchain
add anthropic functions wrapper (#8475)
A cheeky wrapper around Claude that adds function-calling support (kind of, hence it going in experimental). pull/8490/head
parent
490ad93b3c
commit
8f14ddefdf
@ -0,0 +1,287 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5125a1e3",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Anthropic Functions\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use an experimental wrapper around Anthropic that gives it the same API as OpenAI Functions."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "378be79b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain_experimental.llms.anthropic_functions import AnthropicFunctions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "65499965",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize Model\n",
|
||||||
|
"\n",
|
||||||
|
"You can initialize this wrapper the same way you'd initialize ChatAnthropic"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "e1d535f6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = AnthropicFunctions(model='claude-2')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "fcc9eaf4",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Passing in functions\n",
|
||||||
|
"\n",
|
||||||
|
"You can now pass in functions the same way you would with OpenAI Functions."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "0779c320",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"functions=[\n",
|
||||||
|
" {\n",
|
||||||
|
" \"name\": \"get_current_weather\",\n",
|
||||||
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"location\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city and state, e.g. San Francisco, CA\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"enum\": [\"celsius\", \"fahrenheit\"]\n",
|
||||||
|
" }\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"location\"]\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" ]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "ad75a933",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.schema import HumanMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "fc703085",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = model.predict_messages(\n",
|
||||||
|
" [HumanMessage(content=\"whats the weater in boston?\")], \n",
|
||||||
|
" functions=functions\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "04d7936a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=' ', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\", \"unit\": \"fahrenheit\"}'}}, example=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0072fdba",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using for extraction\n",
|
||||||
|
"\n",
|
||||||
|
"You can now use this for extraction."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "7af5c567",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import create_extraction_chain\n",
|
||||||
|
"schema = {\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"name\": {\"type\": \"string\"},\n",
|
||||||
|
" \"height\": {\"type\": \"integer\"},\n",
|
||||||
|
" \"hair_color\": {\"type\": \"string\"},\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"name\", \"height\"],\n",
|
||||||
|
"}\n",
|
||||||
|
"inp = \"\"\"\n",
|
||||||
|
"Alex is 5 feet tall. Claudia is 1 feet taller Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\n",
|
||||||
|
" \"\"\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "bd01082a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = create_extraction_chain(schema, model)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "b5a23e9f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[{'name': 'Alex', 'height': '5', 'hair_color': 'blonde'},\n",
|
||||||
|
" {'name': 'Claudia', 'height': '6', 'hair_color': 'brunette'}]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.run(inp)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "90ec959e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using for tagging\n",
|
||||||
|
"\n",
|
||||||
|
"You can now use this for tagging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"id": "03c1eb0d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import create_tagging_chain"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "581c0ece",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"schema = {\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"sentiment\": {\"type\": \"string\"},\n",
|
||||||
|
" \"aggressiveness\": {\"type\": \"integer\"},\n",
|
||||||
|
" \"language\": {\"type\": \"string\"},\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "d9a8570e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = create_tagging_chain(schema, model)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "cf37d679",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'sentiment': 'positive', 'aggressiveness': '0', 'language': 'english'}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.run(\"this is really cool\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,206 @@
|
|||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from html.parser import HTMLParser
|
||||||
|
from typing import Any, DefaultDict, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
Callbacks,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.anthropic import ChatAnthropic
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import (
|
||||||
|
ChatGeneration,
|
||||||
|
ChatResult,
|
||||||
|
LLMResult,
|
||||||
|
)
|
||||||
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
prompt = """In addition to responding, you can use tools. \
|
||||||
|
You have access to the following tools.
|
||||||
|
|
||||||
|
{tools}
|
||||||
|
|
||||||
|
In order to use a tool, you can use <tool></tool> to specify the name, \
|
||||||
|
and the <tool_input></tool_input> tags to specify the parameters. \
|
||||||
|
Each parameter should be passed in as <$param_name>$value</$param_name>, \
|
||||||
|
Where $param_name is the name of the specific parameter, and $value \
|
||||||
|
is the value for that parameter.
|
||||||
|
|
||||||
|
You will then get back a response in the form <observation></observation>
|
||||||
|
For example, if you have a tool called 'search' that accepts a single \
|
||||||
|
parameter 'query' that could run a google search, in order to search \
|
||||||
|
for the weather in SF you would respond:
|
||||||
|
|
||||||
|
<tool>search</tool><tool_input><query>weather in SF</query></tool_input>
|
||||||
|
<observation>64 degrees</observation>"""
|
||||||
|
|
||||||
|
|
||||||
|
class TagParser(HTMLParser):
    """Parse simple nested tag markup into dictionaries of lists.

    A heavy-handed solution, but it's fast for prototyping. Might be
    re-implemented later to restrict scope to the limited grammar, and for
    more efficiency.

    Uses an HTML parser to parse a limited grammar that allows for syntax of
    the form:

        INPUT -> JUNK? VALUE*
        JUNK -> JUNK_CHARACTER+
        JUNK_CHARACTER -> whitespace | ,
        VALUE -> <IDENTIFIER>DATA</IDENTIFIER> | OBJECT
        OBJECT -> <IDENTIFIER>VALUE+</IDENTIFIER>
        IDENTIFIER -> [a-zA-Z][a-zA-Z0-9_]*
        DATA -> .*

    Tags may repeat and nest, which is how lists and objects are represented.
    (The grammar above is only an approximate specification.)

    After ``feed()``:

    * ``parse_data`` maps each top-level tag name to the list of values seen
      for it, in order. A value is the stripped text for a leaf tag, or a
      ``dict`` of the same shape for a tag that contains other tags.
    * ``success`` is ``False`` if non-junk text appeared outside any tag or a
      closing tag had no matching opening tag.
    """

    def __init__(self) -> None:
        """Create a parser with an empty result tree."""
        super().__init__()

        # Root of the result tree; also the bottom element of the stack.
        self.parse_data: DefaultDict[str, List[Any]] = defaultdict(list)
        # One dict per currently open tag, innermost last.
        self.stack: List[DefaultDict[str, List[str]]] = [self.parse_data]
        self.success = True
        # Number of currently open tags.
        self.depth = 0
        # Text seen since the most recent start/end tag, if any.
        self.data: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: Any) -> None:
        """Hook when a new tag is encountered."""
        self.depth += 1
        self.stack.append(defaultdict(list))
        self.data = None

    def handle_endtag(self, tag: str) -> None:
        """Hook when a tag is closed."""
        if self.depth == 0:
            # Closing tag with nothing open. Popping here would remove the
            # root dict from the stack and the append below would raise
            # IndexError, so flag the parse as invalid and ignore the tag.
            self.success = False
            return

        self.depth -= 1
        # Children collected for the tag being closed (empty for a leaf).
        top_of_stack = dict(self.stack.pop(-1))

        # A tag is a leaf when text was seen since the last tag boundary.
        is_leaf = self.data is not None
        value = self.data if is_leaf else top_of_stack
        # Difficult to type this correctly with mypy (maybe impossible?)
        # Can be nested indefinitely, so requires a self-referencing type.
        self.stack[-1][tag].append(value)  # type: ignore
        # Reset the data so that if we encounter a sequence of end tags, we
        # don't confuse an outer end tag for belonging to a leaf node.
        self.data = None

    def handle_data(self, data: str) -> None:
        """Hook when handling data."""
        stripped_data = data.strip()
        # Outside any tag, the only data that's allowed is whitespace or a
        # comma surrounded by whitespace.
        if self.depth == 0 and stripped_data not in (",", ""):
            # If this is triggered the parse should be considered invalid.
            self.success = False
        if stripped_data:  # ignore whitespace-only strings
            self.data = stripped_data
|
||||||
|
|
||||||
|
|
||||||
|
def _destrip(tool_input: Any) -> Any:
|
||||||
|
if isinstance(tool_input, dict):
|
||||||
|
return {k: _destrip(v) for k, v in tool_input.items()}
|
||||||
|
elif isinstance(tool_input, list):
|
||||||
|
if isinstance(tool_input[0], str):
|
||||||
|
if len(tool_input) == 1:
|
||||||
|
return tool_input[0]
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
elif isinstance(tool_input[0], dict):
|
||||||
|
return [_destrip(v) for v in tool_input]
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_call(completion: str) -> Dict[str, Any]:
    """Parse the ``<tool>``/``<tool_input>`` markup found in ``completion``.

    Generation is requested with ``</tool_input>`` as a stop sequence, so the
    closing tag is re-added before parsing whenever a ``<tool_input>`` was
    opened but not closed.

    Returns:
        The parser's top-level mapping of tag name -> list of parsed values.

    Raises:
        ValueError: If the completion contains no ``<tool_input>`` element.
    """
    text = completion.strip()
    if text.count("<tool_input>") > text.count("</tool_input>"):
        text += "</tool_input>"
    tag_parser = TagParser()
    tag_parser.feed(text)
    parsed = dict(tag_parser.parse_data)
    if not parsed.get("tool_input"):
        raise ValueError(
            f"Could not find a <tool_input> block in the completion: {completion!r}"
        )
    return parsed


class AnthropicFunctions(BaseChatModel):
    """Chat model that emulates OpenAI-style function calling on ChatAnthropic.

    When ``functions`` is passed to a call, the function definitions are
    rendered into a system prompt that asks the model to answer with
    ``<tool>``/``<tool_input>`` tags. A tagged answer is converted into an
    ``AIMessage`` whose ``additional_kwargs["function_call"]`` holds the
    function ``name`` and JSON-encoded ``arguments``.
    """

    # The wrapped chat model that actually produces completions.
    model: ChatAnthropic

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Forward all constructor arguments to the wrapped ChatAnthropic."""
        return {"model": ChatAnthropic(**values)}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a reply, translating tool markup into a ``function_call``.

        Args:
            messages: The conversation so far.
            stop: Optional stop sequences. ``</tool_input>`` is added when
                ``functions`` is given; the caller's list is not modified.
            run_manager: Passed through to the wrapped model as ``callbacks``.
            **kwargs: May contain ``functions`` (a list of function
                definitions) and ``function_call`` (a dict with a ``name``
                key that forces that function). Everything else is forwarded
                to the wrapped model.

        Returns:
            A ``ChatResult`` with one generation: the model's reply unchanged
            if it made no tool call, otherwise an ``AIMessage`` carrying
            ``function_call`` in ``additional_kwargs``.

        Raises:
            ValueError: If ``function_call`` is given without ``functions``,
                or a tool call in the completion cannot be parsed.
        """
        forced = False
        function_call = ""
        if "functions" in kwargs:
            content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2))
            system = SystemMessage(content=content)
            # Builds a new list, so the caller's ``messages`` is left as-is.
            messages = [system] + messages
            del kwargs["functions"]
            # Copy instead of ``append`` so the caller's ``stop`` list does
            # not grow on every call.
            stop = [*(stop or []), "</tool_input>"]
            if "function_call" in kwargs:
                forced = True
                function_call = kwargs["function_call"]["name"]
                # Start the assistant turn with the forced tool name so the
                # model continues with its <tool_input>. The original built
                # this message but never added it to the conversation.
                messages.append(AIMessage(content=f"<tool>{function_call}</tool>"))
                del kwargs["function_call"]
        elif "function_call" in kwargs:
            raise ValueError(
                "if `function_call` provided, `functions` must also be"
            )

        response = self.model.predict_messages(
            messages, stop=stop, callbacks=run_manager, **kwargs
        )
        completion = response.content

        if forced:
            parsed = _parse_tool_call(completion)
            name = function_call
            msg = ""
        elif "<tool>" in completion:
            parsed = _parse_tool_call(completion)
            if not parsed.get("tool"):
                raise ValueError(
                    f"Could not parse a <tool> name from the completion: {completion!r}"
                )
            name = parsed["tool"][0]
            # Whatever the model said before the tool call is kept as content.
            msg = completion.split("<tool>")[0]
        else:
            # Plain answer without any tool call: return it untouched.
            return ChatResult(generations=[ChatGeneration(message=response)])

        additional_kwargs = {
            "function_call": {
                "name": name,
                "arguments": json.dumps(_destrip(parsed["tool_input"][0])),
            }
        }
        message = AIMessage(content=msg, additional_kwargs=additional_kwargs)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Async generation is not supported by this wrapper."""
        raise NotImplementedError

    @property
    def _llm_type(self) -> str:
        """Identifier string for this model type."""
        return "anthropic_functions"
|
Loading…
Reference in New Issue