You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

203 lines
6.9 KiB

"""Utility functions for parsing an OpenAPI spec."""
import copy
import json
import logging
import re
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Union

import requests
import yaml
from openapi_schema_pydantic import (
    Components,
    OpenAPI,
    Operation,
    Parameter,
    PathItem,
    Paths,
    Reference,
    Schema,
)
from pydantic import ValidationError
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class HTTPVerb(str, Enum):
    """HTTP verbs.

    Members are also ``str`` instances, so they compare equal to their
    lowercase value (e.g. ``HTTPVerb.GET == "get"``).
    """

    GET = "get"
    PUT = "put"
    POST = "post"
    DELETE = "delete"
    OPTIONS = "options"
    HEAD = "head"
    PATCH = "patch"
    TRACE = "trace"

    @classmethod
    def from_str(cls, verb: str) -> "HTTPVerb":
        """Parse an HTTP verb.

        Args:
            verb: The verb value to look up, e.g. ``"get"``. The lookup is by
                member value, so it must match the lowercase spelling exactly.

        Returns:
            The matching ``HTTPVerb`` member.

        Raises:
            ValueError: If ``verb`` is not the value of any member.
        """
        try:
            return cls(verb)
        except ValueError:
            # Re-raise with a message that lists the accepted members.
            raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI):
    """OpenAPI Model that removes misformatted parts of the spec."""

    @property
    def _paths_strict(self) -> Paths:
        """Get the paths or err.

        Raises:
            ValueError: If ``self.paths`` is missing or empty.
        """
        if not self.paths:
            raise ValueError("No paths found in spec")
        return self.paths

    def _get_path_strict(self, path: str) -> PathItem:
        """Get the path item registered under ``path`` or err.

        Raises:
            ValueError: If the spec has no paths, or no (truthy) entry
                for ``path``.
        """
        path_item = self._paths_strict.get(path)
        if not path_item:
            raise ValueError(f"No path found for {path}")
        return path_item

    @property
    def _components_strict(self) -> Components:
        """Get components or err.

        Raises:
            ValueError: If ``self.components`` is ``None``.
        """
        if self.components is None:
            raise ValueError("No components found in spec. ")
        return self.components

    @property
    def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
        """Get parameters or err.

        Raises:
            ValueError: If the spec has no components, or the components
                carry no ``parameters`` mapping.
        """
        parameters = self._components_strict.parameters
        if parameters is None:
            raise ValueError("No parameters found in spec. ")
        return parameters

    @property
    def _schemas_strict(self) -> Dict[str, Schema]:
        """Get the dictionary of schemas or err.

        Raises:
            ValueError: If the spec has no components, or the components
                carry no ``schemas`` mapping.
        """
        schemas = self._components_strict.schemas
        if schemas is None:
            raise ValueError("No schemas found in spec. ")
        return schemas

    def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
        """Get a parameter (or nested reference) or err.

        Only the last ``/``-separated segment of ``ref.ref`` is used as the
        lookup key into the component parameters.

        Raises:
            ValueError: If no parameter is registered under that name.
        """
        ref_name = ref.ref.split("/")[-1]
        parameters = self._parameters_strict
        if ref_name not in parameters:
            raise ValueError(f"No parameter found for {ref_name}")
        return parameters[ref_name]

    def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
        """Get the root reference or err.

        Follows references until a non-``Reference`` parameter is reached.
        """
        parameter = self._get_referenced_parameter(ref)
        # NOTE(review): a reference cycle in the spec would make this loop
        # forever; no cycle detection is performed here.
        while isinstance(parameter, Reference):
            parameter = self._get_referenced_parameter(parameter)
        return parameter

    def get_referenced_schema(self, ref: Reference) -> Schema:
        """Get a schema (or nested reference) or err.

        Only the last ``/``-separated segment of ``ref.ref`` is used as the
        lookup key into the component schemas.

        Raises:
            ValueError: If no schema is registered under that name.
        """
        ref_name = ref.ref.split("/")[-1]
        schemas = self._schemas_strict
        if ref_name not in schemas:
            raise ValueError(f"No schema found for {ref_name}")
        return schemas[ref_name]

    def _get_root_referenced_schema(self, ref: Reference) -> Schema:
        """Get the root reference or err.

        Follows references until a non-``Reference`` schema is reached.
        """
        schema = self.get_referenced_schema(ref)
        # NOTE(review): as with parameters, cyclic references are not detected.
        while isinstance(schema, Reference):
            schema = self.get_referenced_schema(schema)
        return schema

    @classmethod
    def parse_obj(cls, obj: Any) -> "OpenAPISpec":
        """Parse ``obj`` into a spec, pruning entries that fail validation.

        On a ``ValidationError`` the offending location of every reported
        error is removed from a deep copy of ``obj`` (the input itself is not
        mutated) and parsing is retried on the pruned copy.

        Raises:
            ValidationError: If pruning removed nothing, so a retry could
                not succeed (for example when a required field is missing).
        """
        try:
            return super().parse_obj(obj)
        except ValidationError as e:
            # We are handling possibly misconfigured specs and want to do a best-effort
            # job to get a reasonable interface out of it.
            new_obj = copy.deepcopy(obj)
            for error in e.errors():
                keys = error["loc"]
                item = new_obj
                # Walk down to the container that holds the failing entry.
                for key in keys[:-1]:
                    item = item[key]
                item.pop(keys[-1], None)
            if new_obj == obj:
                # Nothing was removed: retrying with identical input would
                # recurse until the interpreter's recursion limit is hit.
                raise
            return cls.parse_obj(new_obj)

    @classmethod
    def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
        """Get an OpenAPI spec from a dict."""
        return cls.parse_obj(spec_dict)

    @classmethod
    def from_text(cls, text: str) -> "OpenAPISpec":
        """Get an OpenAPI spec from a text.

        The text is parsed as JSON first; if that fails it is parsed as YAML
        with ``yaml.safe_load``.
        """
        try:
            spec_dict = json.loads(text)
        except json.JSONDecodeError:
            spec_dict = yaml.safe_load(text)
        return cls.from_spec_dict(spec_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
        """Get an OpenAPI spec from a file path.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        path_ = path if isinstance(path, Path) else Path(path)
        if not path_.exists():
            raise FileNotFoundError(f"{path} does not exist")
        with path_.open("r") as f:
            return cls.from_text(f.read())

    @classmethod
    def from_url(cls, url: str) -> "OpenAPISpec":
        """Get an OpenAPI spec from a URL."""
        # NOTE(review): no timeout is passed and the HTTP status is not
        # checked; the response body is parsed whatever the server returned.
        response = requests.get(url)
        return cls.from_text(response.text)

    @property
    def base_url(self) -> str:
        """Get the base url (the URL of the first entry in ``servers``)."""
        return self.servers[0].url

    def get_methods_for_path(self, path: str) -> List[str]:
        """Return a list of valid methods for the specified path.

        The result contains the lowercase verb values, in ``HTTPVerb``
        declaration order, for which the path item defines an ``Operation``.
        """
        path_item = self._get_path_strict(path)
        results = []
        for method in HTTPVerb:
            operation = getattr(path_item, method.value, None)
            if isinstance(operation, Operation):
                results.append(method.value)
        return results

    def get_operation(self, path: str, method: str) -> Operation:
        """Get the operation object for a given path and HTTP method.

        Args:
            path: Path key as it appears in the spec.
            method: Attribute name of the verb on the path item, e.g. ``"get"``.

        Raises:
            ValueError: If the path is unknown or defines no operation for
                ``method``.
        """
        path_item = self._get_path_strict(path)
        operation_obj = getattr(path_item, method, None)
        if not isinstance(operation_obj, Operation):
            raise ValueError(f"No {method} method found for {path}")
        return operation_obj

    def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
        """Get the parameters for a given operation.

        Parameters given as references are resolved to their root
        ``Parameter`` before being added to the result.
        """
        parameters = []
        if operation.parameters:
            for parameter in operation.parameters:
                if isinstance(parameter, Reference):
                    parameter = self._get_root_referenced_parameter(parameter)
                parameters.append(parameter)
        return parameters

    @staticmethod
    def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
        """Get a cleaned operation id from an operation id.

        If the operation has no ``operationId``, one is derived from ``path``
        (leading slashes stripped, every non-alphanumeric character replaced
        by ``_``) followed by ``_`` and ``method``. In either case ``-``,
        ``.`` and ``/`` in the id are replaced by ``_``.
        """
        operation_id = operation.operationId
        if operation_id is None:
            # Replace all punctuation of any kind with underscore
            path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
            operation_id = f"{path}_{method}"
        return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")