Openapi to openai (#6658)

pull/6660/head
Davis Chase 1 year ago committed by GitHub
parent b062a3f938
commit e013459b18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e734b314",
"metadata": {},
"source": [
"# OpenAPI calls with OpenAI functions\n",
"\n",
"In this notebook we'll show how to create a chain that automatically makes calls to an API based only on an OpenAPI spec. Under the hood, we're parsing the OpenAPI spec into a JSON schema that the OpenAI functions API can handle. This allows ChatGPT to automatically select and populate the relevant API call to make for any user input. Using the output of ChatGPT we then make the actual API call, and return the result."
]
},
{
"cell_type": "markdown",
"id": "a95f510a",
"metadata": {},
"source": [
"## Query Klarna"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "08e19b64",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n"
]
}
],
"source": [
"from langchain.chains.openai_functions.openapi import get_openapi_chain\n",
"\n",
"chain = get_openapi_chain(\"https://www.klarna.com/us/shopping/public/openai/v0/api-docs/\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3959f866",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'products': [{'name': \"Tommy Hilfiger Men's Short Sleeve Button-Down Shirt\",\n",
" 'url': 'https://www.klarna.com/us/shopping/pl/cl10001/3204878580/Clothing/Tommy-Hilfiger-Men-s-Short-Sleeve-Button-Down-Shirt/?utm_source=openai&ref-site=openai_plugin',\n",
" 'price': '$26.78',\n",
" 'attributes': ['Material:Linen,Cotton',\n",
" 'Target Group:Man',\n",
" 'Color:Gray,Pink,White,Blue,Beige,Black,Turquoise',\n",
" 'Size:S,XL,M,XXL']},\n",
" {'name': \"Van Heusen Men's Long Sleeve Button-Down Shirt\",\n",
" 'url': 'https://www.klarna.com/us/shopping/pl/cl10001/3201809514/Clothing/Van-Heusen-Men-s-Long-Sleeve-Button-Down-Shirt/?utm_source=openai&ref-site=openai_plugin',\n",
" 'price': '$18.89',\n",
" 'attributes': ['Material:Cotton',\n",
" 'Target Group:Man',\n",
" 'Color:Red,Gray,White,Blue',\n",
" 'Size:XL,XXL']},\n",
" {'name': 'Brixton Bowery Flannel Shirt',\n",
" 'url': 'https://www.klarna.com/us/shopping/pl/cl10001/3202331096/Clothing/Brixton-Bowery-Flannel-Shirt/?utm_source=openai&ref-site=openai_plugin',\n",
" 'price': '$34.48',\n",
" 'attributes': ['Material:Cotton',\n",
" 'Target Group:Man',\n",
" 'Color:Gray,Blue,Black,Orange',\n",
" 'Size:XL,3XL,4XL,5XL,L,M,XXL']},\n",
" {'name': 'Cubavera Four Pocket Guayabera Shirt',\n",
" 'url': 'https://www.klarna.com/us/shopping/pl/cl10001/3202055522/Clothing/Cubavera-Four-Pocket-Guayabera-Shirt/?utm_source=openai&ref-site=openai_plugin',\n",
" 'price': '$23.22',\n",
" 'attributes': ['Material:Polyester,Cotton',\n",
" 'Target Group:Man',\n",
" 'Color:Red,White,Blue,Black',\n",
" 'Size:S,XL,L,M,XXL']},\n",
" {'name': 'Theory Sylvain Shirt - Eclipse',\n",
" 'url': 'https://www.klarna.com/us/shopping/pl/cl10001/3202028254/Clothing/Theory-Sylvain-Shirt-Eclipse/?utm_source=openai&ref-site=openai_plugin',\n",
" 'price': '$86.01',\n",
" 'attributes': ['Material:Polyester,Cotton',\n",
" 'Target Group:Man',\n",
" 'Color:Blue',\n",
" 'Size:S,XL,XS,L,M,XXL']}]}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"What are some options for a men's large blue button down shirt\")"
]
},
{
"cell_type": "markdown",
"id": "6f648c77",
"metadata": {},
"source": [
"## Query a translation service"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bf6cd695",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n",
"Attempting to load an OpenAPI 3.0.1 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n"
]
}
],
"source": [
"chain = get_openapi_chain(\"https://api.speak.com/openapi.yaml\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1ba51609",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'explanation': '<translation language=\"None\" context=\"None\">\\nNone\\n</translation>\\n\\n<alternatives context=\"None\">\\n1. \"N/A\" *(Formal - used in professional settings to indicate that the answer is not applicable)*\\n2. \"I don\\'t have an answer for that\" *(Neutral - commonly used when one does not know the answer to a question)*\\n3. \"I\\'m not sure\" *(Neutral - similar to the above alternative, used when one is unsure of the answer)*\\n</alternatives>\\n\\n<example-convo language=\"None\">\\n<context>None</context>\\n* Tom: \"Do you know what time the concert starts?\"\\n* Sarah: \"I\\'m sorry, I don\\'t have an answer for that.\"\\n</example-convo>\\n\\n*[Report an issue or leave feedback](https://speak.com/chatgpt?rid=p8i6p14duafpctg4ve7tm48z})*',\n",
" 'extra_response_instructions': 'Use all information in the API response and fully render all Markdown.\\nAlways end your response with a link to report an issue or leave feedback on the plugin.'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"How would you say no thanks in Russian\")"
]
},
{
"cell_type": "markdown",
"id": "4923a291",
"metadata": {},
"source": [
"## Query XKCD"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a9198f62",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Attempting to load an OpenAPI 3.0.0 spec. This may result in degraded performance. Convert your OpenAPI spec to 3.1.* spec for better support.\n"
]
}
],
"source": [
"chain = get_openapi_chain(\"https://gist.githubusercontent.com/roaldnefs/053e505b2b7a807290908fe9aa3e1f00/raw/0a212622ebfef501163f91e23803552411ed00e4/openapi.yaml\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3110c398",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..\n"
]
},
{
"data": {
"text/plain": [
"{'month': '6',\n",
" 'num': 2793,\n",
" 'link': '',\n",
" 'year': '2023',\n",
" 'news': '',\n",
" 'safe_title': 'Garden Path Sentence',\n",
" 'transcript': '',\n",
" 'alt': 'Arboretum Owner Denied Standing in Garden Path Suit on Grounds Grounds Appealing Appealing',\n",
" 'img': 'https://imgs.xkcd.com/comics/garden_path_sentence.png',\n",
" 'title': 'Garden Path Sentence',\n",
" 'day': '23'}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"What's today's comic?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,242 @@
import json
import re
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import requests
from openapi_schema_pydantic import Parameter
from requests import Response
from langchain import BasePromptTemplate, LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
summary = getattr(o, "summary", None)
description = getattr(o, "description", None)
if prefer_short:
return summary or description
return description or summary
def _format_url(url: str, path_params: dict) -> str:
expected_path_param = re.findall(r"{(.*?)}", url)
new_params = {}
for param in expected_path_param:
clean_param = param.lstrip(".;").rstrip("*")
val = path_params[clean_param]
if isinstance(val, list):
if param[0] == ".":
sep = "." if param[-1] == "*" else ","
new_val = "." + sep.join(val)
elif param[0] == ";":
sep = f"{clean_param}=" if param[-1] == "*" else ","
new_val = f"{clean_param}=" + sep.join(val)
else:
new_val = ",".join(val)
elif isinstance(val, dict):
kv_sep = "=" if param[-1] == "*" else ","
kv_strs = [kv_sep.join((k, v)) for k, v in val.items()]
if param[0] == ".":
sep = "."
new_val = "."
elif param[0] == ";":
sep = ";"
new_val = ";"
else:
sep = ","
new_val = ""
new_val += sep.join(kv_strs)
else:
if param[0] == ".":
new_val = f".{val}"
elif param[0] == ";":
new_val = f";{clean_param}={val}"
else:
new_val = val
new_params[param] = new_val
return url.format(**new_params)
def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -> dict:
properties = {}
required = []
for p in params:
if p.param_schema:
schema = spec.get_schema(p.param_schema)
else:
media_type_schema = list(p.content.values())[0].media_type_schema # type: ignore # noqa: E501
schema = spec.get_schema(media_type_schema)
if p.description and not schema.description:
schema.description = p.description
properties[p.name] = schema.dict(exclude_none=True)
if p.required:
required.append(p.name)
return {"type": "object", "properties": properties, "required": required}
def openapi_spec_to_openai_fn(
spec: OpenAPISpec,
) -> Tuple[List[Dict[str, Any]], Callable]:
"""Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI
functions.
Args:
spec: OpenAPI spec to convert.
Returns:
Tuple of the OpenAI functions JSON schema and a default function for executing
a request based on the OpenAI function schema.
"""
if not spec.paths:
return [], lambda: None
functions = []
_name_to_call_map = {}
for path in spec.paths:
path_params = {
(p.name, p.param_in): p for p in spec.get_parameters_for_path(path)
}
for method in spec.get_methods_for_path(path):
request_args = {}
op = spec.get_operation(path, method)
op_params = path_params.copy()
for param in spec.get_parameters_for_operation(op):
op_params[(param.name, param.param_in)] = param
params_by_type = defaultdict(list)
for name_loc, p in op_params.items():
params_by_type[name_loc[1]].append(p)
param_loc_to_arg_name = {
"query": "params",
"header": "headers",
"cookie": "cookies",
"path": "path_params",
}
for param_loc, arg_name in param_loc_to_arg_name.items():
if params_by_type[param_loc]:
request_args[arg_name] = _openapi_params_to_json_schema(
params_by_type[param_loc], spec
)
request_body = spec.get_request_body_for_operation(op)
# TODO: Support more MIME types.
if request_body and request_body.content:
media_types = []
for media_type in request_body.content.values():
if media_type.media_type_schema:
schema = spec.get_schema(media_type.media_type_schema)
media_types.append(schema.dict(exclude_none=True))
if len(media_types) == 1:
request_args["data"] = media_types[0]
elif len(media_types) > 1:
request_args["data"] = {"anyOf": media_types}
api_op = APIOperation.from_openapi_spec(spec, path, method)
fn = {
"name": api_op.operation_id,
"description": api_op.description,
"parameters": {
"type": "object",
"properties": request_args,
},
}
functions.append(fn)
_name_to_call_map[fn["name"]] = {
"method": method,
"url": api_op.base_url + api_op.path,
}
def default_call_api(name: str, fn_args: dict, **kwargs: Any) -> Any:
method = _name_to_call_map[name]["method"]
url = _name_to_call_map[name]["url"]
path_params = fn_args.pop("path_params", {})
_format_url(url, path_params)
if "data" in fn_args and isinstance(fn_args["data"], dict):
fn_args["data"] = json.dumps(fn_args["data"])
_kwargs = {**fn_args, **kwargs}
return requests.request(method, url, **_kwargs)
return functions, default_call_api
class SimpleRequestChain(Chain):
request_method: Callable
output_key: str = "response"
input_key: str = "function"
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
name = inputs["function"].pop("name")
args = inputs["function"].pop("arguments")
api_response: Response = self.request_method(name, args)
if api_response.status_code != 200:
response = (
f"{api_response.status_code}: {api_response.reason}"
+ f"\nFor {name} "
+ f"Called with args: {args['params']}"
)
else:
try:
response = api_response.json()
except Exception: # noqa: E722
response = api_response.text
return {self.output_key: response}
def get_openapi_chain(
spec: Union[OpenAPISpec, str],
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
request_chain: Optional[Chain] = None,
) -> SequentialChain:
if isinstance(spec, str):
for conversion in (
OpenAPISpec.from_url,
OpenAPISpec.from_file,
OpenAPISpec.from_text,
):
try:
spec = conversion(spec) # type: ignore[arg-type]
break
except Exception: # noqa: E722
pass
if isinstance(spec, str):
raise ValueError(f"Unable to parse spec from source {spec}")
openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec)
llm = llm or ChatOpenAI(
model="gpt-3.5-turbo-0613",
)
prompt = prompt or ChatPromptTemplate.from_template(
"Use the provided API's to respond to this user query:\n\n{query}"
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={"functions": openai_fns},
output_parser=JsonOutputFunctionsParser(args_only=False),
output_key="function",
)
request_chain = request_chain or SimpleRequestChain(request_method=call_api_fn)
return SequentialChain(
chains=[llm_chain, request_chain],
input_variables=llm_chain.input_keys,
output_variables=["response"],
)

@ -5,6 +5,8 @@ from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation
class OutputFunctionsParser(BaseLLMOutputParser[Any]):
args_only: bool = True
def parse_result(self, result: List[Generation]) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
@ -17,13 +19,18 @@ class OutputFunctionsParser(BaseLLMOutputParser[Any]):
except ValueError as exc:
raise ValueError(f"Could not parse function call: {exc}")
return func_call["arguments"]
if self.args_only:
return func_call["arguments"]
return func_call
class JsonOutputFunctionsParser(OutputFunctionsParser):
def parse_result(self, result: List[Generation]) -> Any:
_args = super().parse_result(result)
return json.loads(_args)
func = super().parse_result(result)
if self.args_only:
return json.loads(func)
func["arguments"] = json.loads(func["arguments"])
return func
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):

@ -59,7 +59,7 @@ SCHEMA_TYPE = Union[str, Type, tuple, None, Enum]
class APIPropertyBase(BaseModel):
"""Base model for an API property."""
# The name of the parameter is required and is case sensitive.
# The name of the parameter is required and is case-sensitive.
# If "in" is "path", the "name" field must correspond to a template expression
# within the path field in the Paths Object.
# If "in" is "header" and the "name" field is "Accept", "Content-Type",

@ -1,269 +1,2 @@
"""Utility functions for parsing an OpenAPI spec."""
import copy
import json
import logging
import re
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Union
import requests
import yaml
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
from pydantic import ValidationError
logger = logging.getLogger(__name__)
class HTTPVerb(str, Enum):
"""HTTP verbs."""
GET = "get"
PUT = "put"
POST = "post"
DELETE = "delete"
OPTIONS = "options"
HEAD = "head"
PATCH = "patch"
TRACE = "trace"
@classmethod
def from_str(cls, verb: str) -> "HTTPVerb":
"""Parse an HTTP verb."""
try:
return cls(verb)
except ValueError:
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes misformatted parts of the spec."""
@property
def _paths_strict(self) -> Paths:
if not self.paths:
raise ValueError("No paths found in spec")
return self.paths
def _get_path_strict(self, path: str) -> PathItem:
path_item = self._paths_strict.get(path)
if not path_item:
raise ValueError(f"No path found for {path}")
return path_item
@property
def _components_strict(self) -> Components:
"""Get components or err."""
if self.components is None:
raise ValueError("No components found in spec. ")
return self.components
@property
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
"""Get parameters or err."""
parameters = self._components_strict.parameters
if parameters is None:
raise ValueError("No parameters found in spec. ")
return parameters
@property
def _schemas_strict(self) -> Dict[str, Schema]:
"""Get the dictionary of schemas or err."""
schemas = self._components_strict.schemas
if schemas is None:
raise ValueError("No schemas found in spec. ")
return schemas
@property
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
"""Get the request body or err."""
request_bodies = self._components_strict.requestBodies
if request_bodies is None:
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
if ref_name not in parameters:
raise ValueError(f"No parameter found for {ref_name}")
return parameters[ref_name]
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err."""
parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference):
parameter = self._get_referenced_parameter(parameter)
return parameter
def get_referenced_schema(self, ref: Reference) -> Schema:
"""Get a schema (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
schemas = self._schemas_strict
if ref_name not in schemas:
raise ValueError(f"No schema found for {ref_name}")
return schemas[ref_name]
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err."""
schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference):
schema = self.get_referenced_schema(schema)
return schema
def _get_referenced_request_body(
self, ref: Reference
) -> Optional[Union[Reference, RequestBody]]:
"""Get a request body (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
request_bodies = self._request_bodies_strict
if ref_name not in request_bodies:
raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
def _get_root_referenced_request_body(
self, ref: Reference
) -> Optional[RequestBody]:
"""Get the root request Body or err."""
request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference):
request_body = self._get_referenced_request_body(request_body)
return request_body
@staticmethod
def _alert_unsupported_spec(obj: dict) -> None:
"""Alert if the spec is not supported."""
warning_message = (
" This may result in degraded performance."
+ " Convert your OpenAPI spec to 3.1.* spec"
+ " for better support."
)
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi")
if isinstance(openapi_version, str):
if openapi_version != "3.1.0":
logger.warning(
f"Attempting to load an OpenAPI {openapi_version}"
f" spec. {warning_message}"
)
else:
pass
elif isinstance(swagger_version, str):
logger.warning(
f"Attempting to load a Swagger {swagger_version}"
f" spec. {warning_message}"
)
else:
raise ValueError(
"Attempting to load an unsupported spec:"
f"\n\n{obj}\n{warning_message}"
)
@classmethod
def parse_obj(cls, obj: dict) -> "OpenAPISpec":
try:
cls._alert_unsupported_spec(obj)
return super().parse_obj(obj)
except ValidationError as e:
# We are handling possibly misconfigured specs and want to do a best-effort
# job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj)
for error in e.errors():
keys = error["loc"]
item = new_obj
for key in keys[:-1]:
item = item[key]
item.pop(keys[-1], None)
return cls.parse_obj(new_obj)
@classmethod
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
"""Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict)
@classmethod
def from_text(cls, text: str) -> "OpenAPISpec":
"""Get an OpenAPI spec from a text."""
try:
spec_dict = json.loads(text)
except json.JSONDecodeError:
spec_dict = yaml.safe_load(text)
return cls.from_spec_dict(spec_dict)
@classmethod
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
"""Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists():
raise FileNotFoundError(f"{path} does not exist")
with path_.open("r") as f:
return cls.from_text(f.read())
@classmethod
def from_url(cls, url: str) -> "OpenAPISpec":
"""Get an OpenAPI spec from a URL."""
response = requests.get(url)
return cls.from_text(response.text)
@property
def base_url(self) -> str:
"""Get the base url."""
return self.servers[0].url
def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path."""
path_item = self._get_path_strict(path)
results = []
for method in HTTPVerb:
operation = getattr(path_item, method.value, None)
if isinstance(operation, Operation):
results.append(method.value)
return results
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
if not isinstance(operation_obj, Operation):
raise ValueError(f"No {method} method found for {path}")
return operation_obj
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
parameters = []
if operation.parameters:
for parameter in operation.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_request_body_for_operation(
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
request_body = operation.requestBody
if isinstance(request_body, Reference):
request_body = self._get_root_referenced_request_body(request_body)
return request_body
@staticmethod
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
# Replace all punctuation of any kind with underscore
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401

@ -0,0 +1,285 @@
"""Utility functions for parsing an OpenAPI spec."""
import copy
import json
import logging
import re
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Union
import requests
import yaml
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
from pydantic import ValidationError
logger = logging.getLogger(__name__)
class HTTPVerb(str, Enum):
"""HTTP verbs."""
GET = "get"
PUT = "put"
POST = "post"
DELETE = "delete"
OPTIONS = "options"
HEAD = "head"
PATCH = "patch"
TRACE = "trace"
@classmethod
def from_str(cls, verb: str) -> "HTTPVerb":
"""Parse an HTTP verb."""
try:
return cls(verb)
except ValueError:
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes misformatted parts of the spec."""
@property
def _paths_strict(self) -> Paths:
if not self.paths:
raise ValueError("No paths found in spec")
return self.paths
def _get_path_strict(self, path: str) -> PathItem:
path_item = self._paths_strict.get(path)
if not path_item:
raise ValueError(f"No path found for {path}")
return path_item
@property
def _components_strict(self) -> Components:
"""Get components or err."""
if self.components is None:
raise ValueError("No components found in spec. ")
return self.components
@property
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
"""Get parameters or err."""
parameters = self._components_strict.parameters
if parameters is None:
raise ValueError("No parameters found in spec. ")
return parameters
@property
def _schemas_strict(self) -> Dict[str, Schema]:
"""Get the dictionary of schemas or err."""
schemas = self._components_strict.schemas
if schemas is None:
raise ValueError("No schemas found in spec. ")
return schemas
@property
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
"""Get the request body or err."""
request_bodies = self._components_strict.requestBodies
if request_bodies is None:
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
if ref_name not in parameters:
raise ValueError(f"No parameter found for {ref_name}")
return parameters[ref_name]
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err."""
parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference):
parameter = self._get_referenced_parameter(parameter)
return parameter
def get_referenced_schema(self, ref: Reference) -> Schema:
"""Get a schema (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
schemas = self._schemas_strict
if ref_name not in schemas:
raise ValueError(f"No schema found for {ref_name}")
return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
if isinstance(schema, Reference):
return self.get_referenced_schema(schema)
return schema
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err."""
schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference):
schema = self.get_referenced_schema(schema)
return schema
def _get_referenced_request_body(
self, ref: Reference
) -> Optional[Union[Reference, RequestBody]]:
"""Get a request body (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
request_bodies = self._request_bodies_strict
if ref_name not in request_bodies:
raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
def _get_root_referenced_request_body(
self, ref: Reference
) -> Optional[RequestBody]:
"""Get the root request Body or err."""
request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference):
request_body = self._get_referenced_request_body(request_body)
return request_body
@staticmethod
def _alert_unsupported_spec(obj: dict) -> None:
"""Alert if the spec is not supported."""
warning_message = (
" This may result in degraded performance."
+ " Convert your OpenAPI spec to 3.1.* spec"
+ " for better support."
)
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi")
if isinstance(openapi_version, str):
if openapi_version != "3.1.0":
logger.warning(
f"Attempting to load an OpenAPI {openapi_version}"
f" spec. {warning_message}"
)
else:
pass
elif isinstance(swagger_version, str):
logger.warning(
f"Attempting to load a Swagger {swagger_version}"
f" spec. {warning_message}"
)
else:
raise ValueError(
"Attempting to load an unsupported spec:"
f"\n\n{obj}\n{warning_message}"
)
@classmethod
def parse_obj(cls, obj: dict) -> "OpenAPISpec":
try:
cls._alert_unsupported_spec(obj)
return super().parse_obj(obj)
except ValidationError as e:
# We are handling possibly misconfigured specs and want to do a best-effort
# job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj)
for error in e.errors():
keys = error["loc"]
item = new_obj
for key in keys[:-1]:
item = item[key]
item.pop(keys[-1], None)
return cls.parse_obj(new_obj)
@classmethod
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
"""Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict)
@classmethod
def from_text(cls, text: str) -> "OpenAPISpec":
"""Get an OpenAPI spec from a text."""
try:
spec_dict = json.loads(text)
except json.JSONDecodeError:
spec_dict = yaml.safe_load(text)
return cls.from_spec_dict(spec_dict)
@classmethod
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
"""Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists():
raise FileNotFoundError(f"{path} does not exist")
with path_.open("r") as f:
return cls.from_text(f.read())
@classmethod
def from_url(cls, url: str) -> "OpenAPISpec":
"""Get an OpenAPI spec from a URL."""
response = requests.get(url)
return cls.from_text(response.text)
@property
def base_url(self) -> str:
"""Get the base url."""
return self.servers[0].url
def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path."""
path_item = self._get_path_strict(path)
results = []
for method in HTTPVerb:
operation = getattr(path_item, method.value, None)
if isinstance(operation, Operation):
results.append(method.value)
return results
def get_parameters_for_path(self, path: str) -> List[Parameter]:
path_item = self._get_path_strict(path)
parameters = []
if not path_item.parameters:
return []
for parameter in path_item.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
if not isinstance(operation_obj, Operation):
raise ValueError(f"No {method} method found for {path}")
return operation_obj
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
parameters = []
if operation.parameters:
for parameter in operation.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_request_body_for_operation(
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
request_body = operation.requestBody
if isinstance(request_body, Reference):
request_body = self._get_root_referenced_request_body(request_body)
return request_body
@staticmethod
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
# Replace all punctuation of any kind with underscore
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
Loading…
Cancel
Save