mirror of https://github.com/hwchase17/langchain
hierarchical planning agent for multi-step queries against larger openapi specs (#2170)
The specs used in chat-gpt plugins have only a few endpoints and have unrealistically small specifications. By contrast, a spec like spotify's has 60+ endpoints and is comprised 100k+ tokens. Here are some impressive traces from gpt-4 that string together non-trivial sequences of API calls. As noted in `planner.py`, gpt-3 is not as robust but can be improved with i) better retry, self-reflect, etc. logic and ii) better few-shots iii) etc. This PR's just a first attempt probing a few different directions that eventually can be made more core. `make me a playlist with songs from kind of blue. call it machine blues.` ``` > Entering new AgentExecutor chain... Action: api_planner Action Input: I need to find the right API calls to create a playlist with songs from Kind of Blue and name it Machine Blues Observation: 1. GET /search to find the album ID for "Kind of Blue". 2. GET /albums/{id}/tracks to get the tracks from the "Kind of Blue" album. 3. GET /me to get the current user's ID. 4. POST /users/{user_id}/playlists to create a new playlist named "Machine Blues" for the current user. 5. POST /playlists/{playlist_id}/tracks to add the tracks from "Kind of Blue" to the newly created "Machine Blues" playlist. Thought:I have a plan to create the playlist. Now, I will execute the API calls. Action: api_controller Action Input: 1. GET /search to find the album ID for "Kind of Blue". 2. GET /albums/{id}/tracks to get the tracks from the "Kind of Blue" album. 3. GET /me to get the current user's ID. 4. POST /users/{user_id}/playlists to create a new playlist named "Machine Blues" for the current user. 5. POST /playlists/{playlist_id}/tracks to add the tracks from "Kind of Blue" to the newly created "Machine Blues" playlist. > Entering new AgentExecutor chain... Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/search?q=Kind%20of%20Blue&type=album", "output_instructions": "Extract the id of the first album in the search results"} Observation: 1weenld61qoidwYuZ1GESA Thought:Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/albums/1weenld61qoidwYuZ1GESA/tracks", "output_instructions": "Extract the ids of all the tracks in the album"} Observation: ["7q3kkfAVpmcZ8g6JUThi3o"] Thought:Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/me", "output_instructions": "Extract the id of the current user"} Observation: 22rhrz4m4kvpxlsb5hezokzwi Thought:Action: requests_post Action Input: {"url": "https://api.spotify.com/v1/users/22rhrz4m4kvpxlsb5hezokzwi/playlists", "data": {"name": "Machine Blues"}, "output_instructions": "Extract the id of the newly created playlist"} Observation: 48YP9TMcEtFu9aGN8n10lg Thought:Action: requests_post Action Input: {"url": "https://api.spotify.com/v1/playlists/48YP9TMcEtFu9aGN8n10lg/tracks", "data": {"uris": ["spotify:track:7q3kkfAVpmcZ8g6JUThi3o"]}, "output_instructions": "Confirm that the tracks were added to the playlist"} Observation: The tracks were added to the playlist. The snapshot_id is "Miw4NTdmMWUxOGU5YWMxMzVmYmE3ZWE5MWZlYWNkMTc2NGVmNTI1ZjY5". Thought:I am finished executing the plan. Final Answer: The tracks from the "Kind of Blue" album have been added to the newly created "Machine Blues" playlist. The playlist ID is 48YP9TMcEtFu9aGN8n10lg. > Finished chain. Observation: The tracks from the "Kind of Blue" album have been added to the newly created "Machine Blues" playlist. The playlist ID is 48YP9TMcEtFu9aGN8n10lg. Thought:I am finished executing the plan and have created the playlist with songs from Kind of Blue, named Machine Blues. Final Answer: I have created a playlist called "Machine Blues" with songs from the "Kind of Blue" album. The playlist ID is 48YP9TMcEtFu9aGN8n10lg. > Finished chain. ``` or `give me a song in the style of tobe nwige` ``` > Entering new AgentExecutor chain... Action: api_planner Action Input: I need to find the right API calls to get a song in the style of Tobe Nwigwe Observation: 1. GET /search to find the artist ID for Tobe Nwigwe. 2. GET /artists/{id}/related-artists to find similar artists to Tobe Nwigwe. 3. Pick one of the related artists and use their artist ID in the next step. 4. GET /artists/{id}/top-tracks to get the top tracks of the chosen related artist. Thought: I'm ready to execute the API calls. Action: api_controller Action Input: 1. GET /search to find the artist ID for Tobe Nwigwe. 2. GET /artists/{id}/related-artists to find similar artists to Tobe Nwigwe. 3. Pick one of the related artists and use their artist ID in the next step. 4. GET /artists/{id}/top-tracks to get the top tracks of the chosen related artist. > Entering new AgentExecutor chain... Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/search?q=Tobe%20Nwigwe&type=artist", "output_instructions": "Extract the artist id for Tobe Nwigwe"} Observation: 3Qh89pgJeZq6d8uM1bTot3 Thought:Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/artists/3Qh89pgJeZq6d8uM1bTot3/related-artists", "output_instructions": "Extract the ids and names of the related artists"} Observation: [ { "id": "75WcpJKWXBV3o3cfluWapK", "name": "Lute" }, { "id": "5REHfa3YDopGOzrxwTsPvH", "name": "Deante' Hitchcock" }, { "id": "6NL31G53xThQXkFs7lDpL5", "name": "Rapsody" }, { "id": "5MbNzCW3qokGyoo9giHA3V", "name": "EARTHGANG" }, { "id": "7Hjbimq43OgxaBRpFXic4x", "name": "Saba" }, { "id": "1ewyVtTZBqFYWIcepopRhp", "name": "Mick Jenkins" } ] Thought:Action: requests_get Action Input: {"url": "https://api.spotify.com/v1/artists/75WcpJKWXBV3o3cfluWapK/top-tracks?country=US", "output_instructions": "Extract the ids and names of the top tracks"} Observation: [ { "id": "6MF4tRr5lU8qok8IKaFOBE", "name": "Under The Sun (with J. Cole & Lute feat. DaBaby)" } ] Thought:I am finished executing the plan. Final Answer: The top track of the related artist Lute is "Under The Sun (with J. Cole & Lute feat. DaBaby)" with the track ID "6MF4tRr5lU8qok8IKaFOBE". > Finished chain. Observation: The top track of the related artist Lute is "Under The Sun (with J. Cole & Lute feat. DaBaby)" with the track ID "6MF4tRr5lU8qok8IKaFOBE". Thought:I am finished executing the plan and have the information the user asked for. Final Answer: The song "Under The Sun (with J. Cole & Lute feat. DaBaby)" by Lute is in the style of Tobe Nwigwe. > Finished chain. ``` --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>pull/2419/head
parent
d6d6f322a9
commit
b026a62bc4
@ -0,0 +1,215 @@
|
||||
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
||||
API_CONTROLLER_PROMPT,
|
||||
API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
API_CONTROLLER_TOOL_NAME,
|
||||
API_ORCHESTRATOR_PROMPT,
|
||||
API_PLANNER_PROMPT,
|
||||
API_PLANNER_TOOL_DESCRIPTION,
|
||||
API_PLANNER_TOOL_NAME,
|
||||
PARSING_GET_PROMPT,
|
||||
PARSING_POST_PROMPT,
|
||||
REQUESTS_GET_TOOL_DESCRIPTION,
|
||||
REQUESTS_POST_TOOL_DESCRIPTION,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.openapi.spec import (
|
||||
ReducedOpenAPISpec,
|
||||
)
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.requests.tool import BaseRequestsTool
|
||||
|
||||
#
|
||||
# Requests tools with LLM-instructed extraction of truncated responses.
|
||||
#
|
||||
# Of course, truncating so bluntly may lose a lot of valuable
|
||||
# information in the response.
|
||||
# However, the goal for now is to have only a single inference step.
|
||||
MAX_RESPONSE_LENGTH = 5000
|
||||
|
||||
|
||||
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
name = "requests_get"
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
llm_chain = LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=PARSING_GET_PROMPT,
|
||||
)
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.get(data["url"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
name = "requests_post"
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
llm_chain = LLMChain(
|
||||
llm=OpenAI(),
|
||||
prompt=PARSING_POST_PROMPT,
|
||||
)
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
response = self.requests_wrapper.post(data["url"], data["data"])
|
||||
response = response[: self.response_length]
|
||||
return self.llm_chain.predict(
|
||||
response=response, instructions=data["output_instructions"]
|
||||
).strip()
|
||||
|
||||
async def _arun(self, text: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
#
|
||||
# Orchestrator, planner, controller.
|
||||
#
|
||||
def _create_api_planner_tool(
|
||||
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
|
||||
) -> Tool:
|
||||
endpoint_descriptions = [
|
||||
f"{name} {description}" for name, description, _ in api_spec.endpoints
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_PLANNER_PROMPT,
|
||||
input_variables=["query"],
|
||||
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
tool = Tool(
|
||||
name=API_PLANNER_TOOL_NAME,
|
||||
description=API_PLANNER_TOOL_DESCRIPTION,
|
||||
func=chain.run,
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
def _create_api_controller_agent(
|
||||
api_url: str,
|
||||
api_docs: str,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> AgentExecutor:
|
||||
tools: List[BaseTool] = [
|
||||
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
|
||||
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_CONTROLLER_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"api_url": api_url,
|
||||
"api_docs": api_docs,
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
||||
def _create_api_controller_tool(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> Tool:
|
||||
"""Expose controller as a tool.
|
||||
|
||||
The tool is invoked with a plan from the planner, and dynamically
|
||||
creates a controller agent with relevant documentation only to
|
||||
constrain the context.
|
||||
"""
|
||||
|
||||
base_url = api_spec.servers[0]["url"] # TODO: do better.
|
||||
|
||||
def _create_and_run_api_controller_agent(plan_str: str) -> str:
|
||||
pattern = r"\b(GET|POST)\s+(/\S+)*"
|
||||
matches = re.findall(pattern, plan_str)
|
||||
endpoint_names = [
|
||||
"{method} {route}".format(method=method, route=route.split("?")[0])
|
||||
for method, route in matches
|
||||
]
|
||||
endpoint_docs_by_name = {name: docs for name, _, docs in api_spec.endpoints}
|
||||
docs_str = ""
|
||||
for endpoint_name in endpoint_names:
|
||||
docs = endpoint_docs_by_name.get(endpoint_name)
|
||||
if not docs:
|
||||
raise ValueError(f"{endpoint_name} endpoint does not exist.")
|
||||
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
|
||||
|
||||
agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
|
||||
return agent.run(plan_str)
|
||||
|
||||
return Tool(
|
||||
name=API_CONTROLLER_TOOL_NAME,
|
||||
func=_create_and_run_api_controller_agent,
|
||||
description=API_CONTROLLER_TOOL_DESCRIPTION,
|
||||
)
|
||||
|
||||
|
||||
def create_openapi_agent(
|
||||
api_spec: ReducedOpenAPISpec,
|
||||
requests_wrapper: RequestsWrapper,
|
||||
llm: BaseLanguageModel,
|
||||
) -> AgentExecutor:
|
||||
"""Instantiate API planner and controller for a given spec.
|
||||
|
||||
Inject credentials via requests_wrapper.
|
||||
|
||||
We use a top-level "orchestrator" agent to invoke the planner and controller,
|
||||
rather than a top-level planner
|
||||
that invokes a controller with its plan. This is to keep the planner simple.
|
||||
"""
|
||||
tools = [
|
||||
_create_api_planner_tool(api_spec, llm),
|
||||
_create_api_controller_tool(api_spec, requests_wrapper, llm),
|
||||
]
|
||||
prompt = PromptTemplate(
|
||||
template=API_ORCHESTRATOR_PROMPT,
|
||||
input_variables=["input", "agent_scratchpad"],
|
||||
partial_variables={
|
||||
"tool_names": ", ".join([tool.name for tool in tools]),
|
||||
"tool_descriptions": "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in tools]
|
||||
),
|
||||
},
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=LLMChain(llm=llm, prompt=prompt),
|
||||
allowed_tools=[tool.name for tool in tools],
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
@ -0,0 +1,156 @@
|
||||
# flake8: noqa
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
|
||||
|
||||
You should:
|
||||
1) evaluate whether the user query can be solved by the API documentated below. If no, say why.
|
||||
2) if yes, generate a plan of API calls and say what they are doing step by step.
|
||||
|
||||
You should only use API endpoints documented below ("Endpoints you can use:").
|
||||
Some user queries can be resolved in a single API call, but some will require several API calls.
|
||||
The plan will be passed to an API controller that can format it into web requests and return the responses.
|
||||
|
||||
----
|
||||
|
||||
Here are some examples:
|
||||
|
||||
Fake endpoints for examples:
|
||||
GET /user to get information about the current user
|
||||
GET /products/search search across products
|
||||
POST /users/{{id}}/cart to add products to a user's cart
|
||||
|
||||
User query: tell me a joke
|
||||
Plan: Sorry, this API's domain is shopping, not comedy.
|
||||
|
||||
Usery query: I want to buy a couch
|
||||
Plan: 1. GET /products/search to search for couches
|
||||
2. GET /user to find the user's id
|
||||
3. POST /users/{{id}}/cart to add a couch to the user's cart
|
||||
|
||||
----
|
||||
|
||||
Here are endpoints you can use. Do not reference any of the endpoints above.
|
||||
|
||||
{endpoints}
|
||||
|
||||
----
|
||||
|
||||
User query: {query}
|
||||
Plan:"""
|
||||
API_PLANNER_TOOL_NAME = "api_planner"
|
||||
API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to calling the API controller."
|
||||
|
||||
# Execution.
|
||||
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
|
||||
If you cannot complete them and run into issues, you should explain the issue. If you're able to resolve an API call call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
|
||||
|
||||
|
||||
Here is documentation on the API:
|
||||
Base url: {api_url}
|
||||
Endpoints:
|
||||
{api_docs}
|
||||
|
||||
|
||||
Here are tools to execute requests against the API: {tool_descriptions}
|
||||
|
||||
|
||||
Starting below, you should follow this format:
|
||||
|
||||
Plan: the plan of API calls to execute
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of the tools [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the output of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
|
||||
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
|
||||
|
||||
|
||||
Begin!
|
||||
|
||||
Plan: {input}
|
||||
Thought:
|
||||
{agent_scratchpad}
|
||||
"""
|
||||
API_CONTROLLER_TOOL_NAME = "api_controller"
|
||||
API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
|
||||
|
||||
# Orchestrate planning + execution.
|
||||
# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
|
||||
# keeping planning (and specifically the planning prompt) simple.
|
||||
API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources.
|
||||
Some user queries can be resolved in a single API call though some require several API call.
|
||||
You should always plan your API calls first, and then execute the plan second.
|
||||
You should never return information without executing the api_controller tool.
|
||||
|
||||
|
||||
Here are the tools to plan and execute API requests: {tool_descriptions}
|
||||
|
||||
|
||||
Starting below, you should follow this format:
|
||||
|
||||
User query: the query a User wants help with related to the API
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of the tools [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I am finished executing a plan and have the information the user asked for or the data the used asked to create
|
||||
Final Answer: the final output from executing the plan
|
||||
|
||||
|
||||
Example:
|
||||
User query: can you add some trendy stuff to my shopping cart.
|
||||
Thought: I should plan API calls first.
|
||||
Action: api_planner
|
||||
Action Input: I need to find the right API calls to add trendy items to the users shopping cart
|
||||
Observation: 1) GET /items/trending to get trending item ids
|
||||
2) GET /user to get user
|
||||
3) POST /cart to post the trending items to the user's cart
|
||||
Thought: I'm ready to execute the API calls.
|
||||
Action: api_controller
|
||||
Action Input: 1) GET /items/trending to get trending item ids
|
||||
2) GET /user to get user
|
||||
3) POST /cart to post the trending items to the user's cart
|
||||
...
|
||||
|
||||
Begin!
|
||||
|
||||
User query: {input}
|
||||
Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
|
||||
Input to the tool should be a json string with 2 keys: "url" and "output_instructions".
|
||||
The value of "url" should be a string. The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the GET request fetches.
|
||||
"""
|
||||
|
||||
PARSING_GET_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
|
||||
REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
|
||||
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
|
||||
The value of "url" should be a string.
|
||||
The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
|
||||
The value of "summary_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
|
||||
Always use double quotes for strings in the json string."""
|
||||
|
||||
PARSING_POST_PROMPT = PromptTemplate(
|
||||
template="""Here is an API response:\n\n{response}\n\n====
|
||||
Your task is to extract some information according to these instructions: {instructions}
|
||||
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
|
||||
If the response indicates an error, you should instead output a summary of the error.
|
||||
|
||||
Output:""",
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
@ -0,0 +1,110 @@
|
||||
"""Quick and dirty representation for OpenAPI specs."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
|
||||
def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
|
||||
"""Try to substitute $refs.
|
||||
|
||||
The goal is to get the complete docs for each endpoint in context for now.
|
||||
|
||||
In the few OpenAPI specs I studied, $refs referenced models
|
||||
(or in OpenAPI terms, components) and could be nested. This code most
|
||||
likely misses lots of cases.
|
||||
"""
|
||||
|
||||
def _retrieve_ref_path(path: str, full_spec: dict) -> dict:
|
||||
components = path.split("/")
|
||||
if components[0] != "#":
|
||||
raise RuntimeError(
|
||||
"All $refs I've seen so far are uri fragments (start with hash)."
|
||||
)
|
||||
out = full_spec
|
||||
for component in components[1:]:
|
||||
out = out[component]
|
||||
return out
|
||||
|
||||
def _dereference_refs(
|
||||
obj: Union[dict, list], stop: bool = False
|
||||
) -> Union[dict, list]:
|
||||
if stop:
|
||||
return obj
|
||||
obj_out: Dict[str, Any] = {}
|
||||
if isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
if k == "$ref":
|
||||
# stop=True => don't dereference recursively.
|
||||
return _dereference_refs(
|
||||
_retrieve_ref_path(v, full_spec), stop=True
|
||||
)
|
||||
elif isinstance(v, list):
|
||||
obj_out[k] = [_dereference_refs(el) for el in v]
|
||||
elif isinstance(v, dict):
|
||||
obj_out[k] = _dereference_refs(v)
|
||||
else:
|
||||
obj_out[k] = v
|
||||
return obj_out
|
||||
elif isinstance(obj, list):
|
||||
return [_dereference_refs(el) for el in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
return _dereference_refs(spec_obj)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReducedOpenAPISpec:
|
||||
servers: List[dict]
|
||||
description: str
|
||||
endpoints: List[Tuple[str, str, dict]]
|
||||
|
||||
|
||||
def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
|
||||
"""Simplify/distill/minify a spec somehow.
|
||||
|
||||
I want a smaller target for retrieval and (more importantly)
|
||||
I want smaller results from retrieval.
|
||||
I was hoping https://openapi.tools/ would have some useful bits
|
||||
to this end, but doesn't seem so.
|
||||
"""
|
||||
# 1. Consider only get, post endpoints.
|
||||
endpoints = [
|
||||
(f"{operation_name.upper()} {route}", docs.get("description"), docs)
|
||||
for route, operation in spec["paths"].items()
|
||||
for operation_name, docs in operation.items()
|
||||
if operation_name in ["get", "post"]
|
||||
]
|
||||
|
||||
# 2. Replace any refs so that complete docs are retrieved.
|
||||
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
|
||||
if dereference:
|
||||
endpoints = [
|
||||
(name, description, dereference_refs(docs, spec))
|
||||
for name, description, docs in endpoints
|
||||
]
|
||||
|
||||
# 3. Strip docs down to required request args + happy path response.
|
||||
def reduce_endpoint_docs(docs: dict) -> dict:
|
||||
out = {}
|
||||
if docs.get("description"):
|
||||
out["description"] = docs.get("description")
|
||||
if docs.get("parameters"):
|
||||
out["parameters"] = [
|
||||
parameter
|
||||
for parameter in docs.get("parameters", [])
|
||||
if parameter.get("required")
|
||||
]
|
||||
if "200" in docs["responses"]:
|
||||
out["responses"] = docs["responses"]["200"]
|
||||
return out
|
||||
|
||||
endpoints = [
|
||||
(name, description, reduce_endpoint_docs(docs))
|
||||
for name, description, docs in endpoints
|
||||
]
|
||||
return ReducedOpenAPISpec(
|
||||
servers=spec["servers"],
|
||||
description=spec["info"].get("description", ""),
|
||||
endpoints=endpoints,
|
||||
)
|
Loading…
Reference in New Issue