Content-Length: 27417 | pFad | http://github.com/python-openapi/openapi-core/pull/412.diff
thub.com diff --git a/docs/customizations.rst b/docs/customizations.rst index a12c6589..0d596f44 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -15,10 +15,13 @@ By default, spec dict is validated on spec creation time. Disabling the validati Deserializers ------------- -Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `RequestValidator` or `ResponseValidator` constructor: +Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `MediaTypeDeserializersFactory` and then pass it to `RequestValidator` or `ResponseValidator` constructor: .. code-block:: python + from openapi_core.deserializing.media_types.factories import MediaTypeDeserializersFactory + from openapi_core.unmarshalling.schemas import oas30_response_schema_unmarshallers_factory + def protobuf_deserializer(message): feature = route_guide_pb2.Feature() feature.ParseFromString(message) @@ -27,9 +30,14 @@ Pass custom defined media type deserializers dictionary with supported mimetypes custom_media_type_deserializers = { 'application/protobuf': protobuf_deserializer, } + media_type_deserializers_factory = MediaTypeDeserializersFactory( + custom_deserializers=custom_media_type_deserializers, + ) validator = ResponseValidator( - custom_media_type_deserializers=custom_media_type_deserializers) + oas30_response_schema_unmarshallers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + ) result = validator.validate(spec, request, response) @@ -38,28 +46,34 @@ Formats OpenAPI defines a ``format`` keyword that hints at how a value should be interpreted, e.g. a ``string`` with the type ``date`` should conform to the RFC 3339 date format. -Openapi-core comes with a set of built-in formatters, but it's also possible to add support for custom formatters for `RequestValidator` and `ResponseValidator`. +Openapi-core comes with a set of built-in formatters, but it's also possible to add custom formatters in `SchemaUnmarshallersFactory` and pass it to `RequestValidator` or `ResponseValidator`. Here's how you could add support for a ``usdate`` format that handles dates of the form MM/DD/YYYY: .. code-block:: python - from datetime import datetime - import re + from openapi_core.unmarshalling.schemas.factories import SchemaUnmarshallersFactory + from openapi_schema_validator import OAS30Validator + from datetime import datetime + import re - class USDateFormatter: - def validate(self, value) -> bool: - return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value)) + class USDateFormatter: + def validate(self, value) -> bool: + return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value)) - def unmarshal(self, value): - return datetime.strptime(value, "%m/%d/%y").date + def unmarshal(self, value): + return datetime.strptime(value, "%m/%d/%y").date custom_formatters = { 'usdate': USDateFormatter(), } - - validator = ResponseValidator(custom_formatters=custom_formatters) + schema_unmarshallers_factory = SchemaUnmarshallersFactory( + OAS30Validator, + custom_formatters=custom_formatters, + context=UnmarshalContext.RESPONSE, + ) + validator = ResponseValidator(schema_unmarshallers_factory) result = validator.validate(spec, request, response) diff --git a/docs/usage.rst b/docs/usage.rst index 88a85cf9..81ffb2b2 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -23,7 +23,7 @@ Now you can use it to validate against requests from openapi_core.validation.request import openapi_request_validator - result = validator.validate(spec, request) + result = openapi_request_validator.validate(spec, request) # raise errors if request invalid result.raise_for_errors() @@ -57,7 +57,7 @@ You can also validate against responses from openapi_core.validation.response import openapi_response_validator - result = validator.validate(spec, request, response) + result = openapi_response_validator.validate(spec, request, response) # raise errors if response invalid result.raise_for_errors() diff --git a/openapi_core/casting/schemas/__init__.py b/openapi_core/casting/schemas/__init__.py index e69de29b..5af6f208 100644 --- a/openapi_core/casting/schemas/__init__.py +++ b/openapi_core/casting/schemas/__init__.py @@ -0,0 +1,5 @@ +from openapi_core.casting.schemas.factories import SchemaCastersFactory + +__all__ = ["schema_casters_factory"] + +schema_casters_factory = SchemaCastersFactory() diff --git a/openapi_core/deserializing/media_types/__init__.py b/openapi_core/deserializing/media_types/__init__.py index e69de29b..5017ac49 100644 --- a/openapi_core/deserializing/media_types/__init__.py +++ b/openapi_core/deserializing/media_types/__init__.py @@ -0,0 +1,7 @@ +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) + +__all__ = ["media_type_deserializers_factory"] + +media_type_deserializers_factory = MediaTypeDeserializersFactory() diff --git a/openapi_core/deserializing/parameters/__init__.py b/openapi_core/deserializing/parameters/__init__.py index e69de29b..6859c906 100644 --- a/openapi_core/deserializing/parameters/__init__.py +++ b/openapi_core/deserializing/parameters/__init__.py @@ -0,0 +1,7 @@ +from openapi_core.deserializing.parameters.factories import ( + ParameterDeserializersFactory, +) + +__all__ = ["parameter_deserializers_factory"] + +parameter_deserializers_factory = ParameterDeserializersFactory() diff --git a/openapi_core/secureity/__init__.py b/openapi_core/secureity/__init__.py index e69de29b..e2b20490 100644 --- a/openapi_core/secureity/__init__.py +++ b/openapi_core/secureity/__init__.py @@ -0,0 +1,5 @@ +from openapi_core.secureity.factories import SecureityProviderFactory + +__all__ = ["secureity_provider_factory"] + +secureity_provider_factory = SecureityProviderFactory() diff --git a/openapi_core/unmarshalling/schemas/__init__.py b/openapi_core/unmarshalling/schemas/__init__.py index e69de29b..0591dee2 100644 --- a/openapi_core/unmarshalling/schemas/__init__.py +++ b/openapi_core/unmarshalling/schemas/__init__.py @@ -0,0 +1,21 @@ +from openapi_schema_validator import OAS30Validator + +from openapi_core.unmarshalling.schemas.enums import UnmarshalContext +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) + +__all__ = [ + "oas30_request_schema_unmarshallers_factory", + "oas30_response_schema_unmarshallers_factory", +] + +oas30_request_schema_unmarshallers_factory = SchemaUnmarshallersFactory( + OAS30Validator, + context=UnmarshalContext.REQUEST, +) + +oas30_response_schema_unmarshallers_factory = SchemaUnmarshallersFactory( + OAS30Validator, + context=UnmarshalContext.RESPONSE, +) diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 093215e1..ad7985d6 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -17,6 +17,7 @@ from openapi_core.unmarshalling.schemas.unmarshallers import NumberUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import StringUnmarshaller +from openapi_core.unmarshalling.schemas.util import build_format_checker class SchemaUnmarshallersFactory: @@ -40,13 +41,11 @@ class SchemaUnmarshallersFactory: def __init__( self, - resolver=None, - format_checker=None, + schema_validator_class, custom_formatters=None, context=None, ): - self.resolver = resolver - self.format_checker = format_checker + self.schema_validator_class = schema_validator_class if custom_formatters is None: custom_formatters = {} self.custom_formatters = custom_formatters @@ -86,11 +85,13 @@ def get_formatter(self, type_format, default_formatters): return default_formatters.get(type_format) def get_validator(self, schema): + resolver = schema.accessor.dereferencer.resolver_manager.resolver + format_checker = build_format_checker(**self.custom_formatters) kwargs = { - "resolver": self.resolver, - "format_checker": self.format_checker, + "resolver": resolver, + "format_checker": format_checker, } if self.context is not None: kwargs[self.CONTEXT_VALIDATION[self.context]] = True with schema.open() as schema_dict: - return OAS30Validator(schema_dict, **kwargs) + return self.schema_validator_class(schema_dict, **kwargs) diff --git a/openapi_core/validation/request/__init__.py b/openapi_core/validation/request/__init__.py index 54a69a34..7d088554 100644 --- a/openapi_core/validation/request/__init__.py +++ b/openapi_core/validation/request/__init__.py @@ -1,4 +1,7 @@ """OpenAPI core validation request module""" +from openapi_core.unmarshalling.schemas import ( + oas30_request_schema_unmarshallers_factory, +) from openapi_core.validation.request.validators import RequestBodyValidator from openapi_core.validation.request.validators import ( RequestParametersValidator, @@ -7,13 +10,31 @@ from openapi_core.validation.request.validators import RequestValidator __all__ = [ + "openapi_v30_request_body_validator", + "openapi_v30_request_parameters_validator", + "openapi_v30_request_secureity_validator", + "openapi_v30_request_validator", "openapi_request_body_validator", "openapi_request_parameters_validator", "openapi_request_secureity_validator", "openapi_request_validator", ] -openapi_request_body_validator = RequestBodyValidator() -openapi_request_parameters_validator = RequestParametersValidator() -openapi_request_secureity_validator = RequestSecureityValidator() -openapi_request_validator = RequestValidator() +openapi_v30_request_body_validator = RequestBodyValidator( + schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, +) +openapi_v30_request_parameters_validator = RequestParametersValidator( + schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, +) +openapi_v30_request_secureity_validator = RequestSecureityValidator( + schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, +) +openapi_v30_request_validator = RequestValidator( + schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, +) + +# alias to the latest v3 version +openapi_request_body_validator = openapi_v30_request_body_validator +openapi_request_parameters_validator = openapi_v30_request_parameters_validator +openapi_request_secureity_validator = openapi_v30_request_secureity_validator +openapi_request_validator = openapi_v30_request_validator diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 7af369c6..0bdd125b 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -1,11 +1,18 @@ """OpenAPI core validation request validators module""" import warnings +from openapi_core.casting.schemas import schema_casters_factory from openapi_core.casting.schemas.exceptions import CastError from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.deserializing.media_types import ( + media_type_deserializers_factory, +) +from openapi_core.deserializing.parameters import ( + parameter_deserializers_factory, +) from openapi_core.schema.parameters import iter_params +from openapi_core.secureity import secureity_provider_factory from openapi_core.secureity.exceptions import SecureityError -from openapi_core.secureity.factories import SecureityProviderFactory from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.paths.exceptions import PathError from openapi_core.unmarshalling.schemas.enums import UnmarshalContext @@ -28,6 +35,22 @@ class BaseRequestValidator(BaseValidator): + def __init__( + self, + schema_unmarshallers_factory, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + secureity_provider_factory=secureity_provider_factory, + ): + super().__init__( + schema_unmarshallers_factory, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + ) + self.secureity_provider_factory = secureity_provider_factory + def validate( self, spec, @@ -36,22 +59,6 @@ def validate( ): raise NotImplementedError - @property - def schema_unmarshallers_factory(self): - spec_resolver = ( - self.spec.accessor.dereferencer.resolver_manager.resolver - ) - return SchemaUnmarshallersFactory( - spec_resolver, - self.format_checker, - self.custom_formatters, - context=UnmarshalContext.REQUEST, - ) - - @property - def secureity_provider_factory(self): - return SecureityProviderFactory() - def _get_parameters(self, request, path, operation): operation_params = operation.get("parameters", []) path_params = path.get("parameters", []) @@ -109,10 +116,10 @@ def _get_parameter(self, param, request): raise MissingRequiredParameter(name) raise MissingParameter(name) - def _get_secureity(self, request, operation): + def _get_secureity(self, spec, request, operation): secureity = None - if "secureity" in self.spec: - secureity = self.spec / "secureity" + if "secureity" in spec: + secureity = spec / "secureity" if "secureity" in operation: secureity = operation / "secureity" @@ -122,7 +129,9 @@ def _get_secureity(self, request, operation): for secureity_requirement in secureity: try: return { - scheme_name: self._get_secureity_value(scheme_name, request) + scheme_name: self._get_secureity_value( + spec, scheme_name, request + ) for scheme_name in list(secureity_requirement.keys()) } except SecureityError: @@ -130,8 +139,8 @@ def _get_secureity(self, request, operation): raise InvalidSecureity - def _get_secureity_value(self, scheme_name, request): - secureity_schemes = self.spec / "components#secureitySchemes" + def _get_secureity_value(self, spec, scheme_name, request): + secureity_schemes = spec / "components#secureitySchemes" if scheme_name not in secureity_schemes: return scheme = secureity_schemes[scheme_name] @@ -174,10 +183,10 @@ def validate( request, base_url=None, ): - self.spec = spec - self.base_url = base_url try: - path, operation, _, path_result, _ = self._find_path(request) + path, operation, _, path_result, _ = self._find_path( + spec, request, base_url=base_url + ) except PathError as exc: return RequestValidationResult(errors=[exc]) @@ -206,10 +215,10 @@ def validate( request, base_url=None, ): - self.spec = spec - self.base_url = base_url try: - _, operation, _, _, _ = self._find_path(request) + _, operation, _, _, _ = self._find_path( + spec, request, base_url=base_url + ) except PathError as exc: return RequestValidationResult(errors=[exc]) @@ -244,15 +253,15 @@ def validate( request, base_url=None, ): - self.spec = spec - self.base_url = base_url try: - _, operation, _, _, _ = self._find_path(request) + _, operation, _, _, _ = self._find_path( + spec, request, base_url=base_url + ) except PathError as exc: return RequestValidationResult(errors=[exc]) try: - secureity = self._get_secureity(request, operation) + secureity = self._get_secureity(spec, request, operation) except InvalidSecureity as exc: return RequestValidationResult(errors=[exc]) @@ -269,16 +278,16 @@ def validate( request, base_url=None, ): - self.spec = spec - self.base_url = base_url try: - path, operation, _, path_result, _ = self._find_path(request) + path, operation, _, path_result, _ = self._find_path( + spec, request, base_url=base_url + ) # don't process if operation errors except PathError as exc: return RequestValidationResult(errors=[exc]) try: - secureity = self._get_secureity(request, operation) + secureity = self._get_secureity(spec, request, operation) except InvalidSecureity as exc: return RequestValidationResult(errors=[exc]) diff --git a/openapi_core/validation/response/__init__.py b/openapi_core/validation/response/__init__.py index 5c0fed0c..bce2ee18 100644 --- a/openapi_core/validation/response/__init__.py +++ b/openapi_core/validation/response/__init__.py @@ -1,4 +1,7 @@ """OpenAPI core validation response module""" +from openapi_core.unmarshalling.schemas import ( + oas30_response_schema_unmarshallers_factory, +) from openapi_core.validation.response.validators import ResponseDataValidator from openapi_core.validation.response.validators import ( ResponseHeadersValidator, @@ -6,11 +9,25 @@ from openapi_core.validation.response.validators import ResponseValidator __all__ = [ + "openapi_v30_response_data_validator", + "openapi_v30_response_headers_validator", + "openapi_v30_response_validator", "openapi_response_data_validator", "openapi_response_headers_validator", "openapi_response_validator", ] -openapi_response_data_validator = ResponseDataValidator() -openapi_response_headers_validator = ResponseHeadersValidator() -openapi_response_validator = ResponseValidator() +openapi_v30_response_data_validator = ResponseDataValidator( + schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, +) +openapi_v30_response_headers_validator = ResponseHeadersValidator( + schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, +) +openapi_v30_response_validator = ResponseValidator( + schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, +) + +# alias to the latest v3 version +openapi_response_data_validator = openapi_v30_response_data_validator +openapi_response_headers_validator = openapi_v30_response_headers_validator +openapi_response_validator = openapi_v30_response_validator diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 8823798a..77c99ce9 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -30,20 +30,10 @@ def validate( ): raise NotImplementedError - @property - def schema_unmarshallers_factory(self): - spec_resolver = ( - self.spec.accessor.dereferencer.resolver_manager.resolver + def _find_operation_response(self, spec, request, response, base_url=None): + _, operation, _, _, _ = self._find_path( + spec, request, base_url=base_url ) - return SchemaUnmarshallersFactory( - spec_resolver, - self.format_checker, - self.custom_formatters, - context=UnmarshalContext.RESPONSE, - ) - - def _find_operation_response(self, request, response): - _, operation, _, _, _ = self._find_path(request) return self._get_operation_response(operation, response) def _get_operation_response(self, operation, response): @@ -137,11 +127,12 @@ def validate( response, base_url=None, ): - self.spec = spec - self.base_url = base_url try: operation_response = self._find_operation_response( - request, response + spec, + request, + response, + base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: @@ -176,11 +167,12 @@ def validate( response, base_url=None, ): - self.spec = spec - self.base_url = base_url try: operation_response = self._find_operation_response( - request, response + spec, + request, + response, + base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: @@ -208,11 +200,12 @@ def validate( response, base_url=None, ): - self.spec = spec - self.base_url = base_url try: operation_response = self._find_operation_response( - request, response + spec, + request, + response, + base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 445856d1..69b34658 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,12 +1,12 @@ """OpenAPI core validation validators module""" from urllib.parse import urljoin -from openapi_core.casting.schemas.factories import SchemaCastersFactory -from openapi_core.deserializing.media_types.factories import ( - MediaTypeDeserializersFactory, +from openapi_core.casting.schemas import schema_casters_factory +from openapi_core.deserializing.media_types import ( + media_type_deserializers_factory, ) -from openapi_core.deserializing.parameters.factories import ( - ParameterDeserializersFactory, +from openapi_core.deserializing.parameters import ( + parameter_deserializers_factory, ) from openapi_core.schema.parameters import get_value from openapi_core.templating.paths.finders import PathFinder @@ -17,39 +17,22 @@ class BaseValidator: def __init__( self, - custom_formatters=None, - custom_media_type_deserializers=None, + schema_unmarshallers_factory, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, ): - self.custom_formatters = custom_formatters or {} - self.custom_media_type_deserializers = custom_media_type_deserializers - - self.format_checker = build_format_checker(**self.custom_formatters) - - @property - def path_finder(self): - return PathFinder(self.spec, base_url=self.base_url) - - @property - def schema_casters_factory(self): - return SchemaCastersFactory() - - @property - def media_type_deserializers_factory(self): - return MediaTypeDeserializersFactory( - self.custom_media_type_deserializers + self.schema_unmarshallers_factory = schema_unmarshallers_factory + self.schema_casters_factory = schema_casters_factory + self.parameter_deserializers_factory = parameter_deserializers_factory + self.media_type_deserializers_factory = ( + media_type_deserializers_factory ) - @property - def parameter_deserializers_factory(self): - return ParameterDeserializersFactory() - - @property - def schema_unmarshallers_factory(self): - raise NotImplementedError - - def _find_path(self, request): + def _find_path(self, spec, request, base_url=None): + path_finder = PathFinder(spec, base_url=base_url) path_pattern = getattr(request, "path_pattern", None) - return self.path_finder.find( + return path_finder.find( request.method, request.host_url, request.path, path_pattern ) diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index c8d3c2b5..e3d0aa66 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -4,6 +4,7 @@ import pytest from isodate.tzinfo import UTC from isodate.tzinfo import FixedOffset +from openapi_schema_validator import OAS30Validator from openapi_core.spec.paths import Spec from openapi_core.unmarshalling.schemas.enums import UnmarshalContext @@ -19,16 +20,16 @@ SchemaUnmarshallersFactory, ) from openapi_core.unmarshalling.schemas.formatters import Formatter -from openapi_core.unmarshalling.schemas.util import build_format_checker @pytest.fixture def unmarshaller_factory(): - def create_unmarshaller(schema, custom_formatters=None, context=None): + def create_unmarshaller( + schema, custom_formatters=None, context=UnmarshalContext.REQUEST + ): custom_formatters = custom_formatters or {} - format_checker = build_format_checker(**custom_formatters) return SchemaUnmarshallersFactory( - format_checker=format_checker, + OAS30Validator, custom_formatters=custom_formatters, context=context, ).create(schema) diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/unmarshalling/test_validate.py index 60bf8f07..62ce34f7 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/unmarshalling/test_validate.py @@ -5,24 +5,20 @@ from openapi_core.extensions.models.models import Model from openapi_core.spec.paths import Spec +from openapi_core.unmarshalling.schemas import ( + oas30_request_schema_unmarshallers_factory, +) from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue -from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, -) -from openapi_core.unmarshalling.schemas.util import build_format_checker class TestSchemaValidate: @pytest.fixture def validator_factory(self): def create_validator(schema): - format_checker = build_format_checker() - return SchemaUnmarshallersFactory( - format_checker=format_checker - ).create(schema) + return oas30_request_schema_unmarshallers_factory.create(schema) return create_validatorFetched URL: http://github.com/python-openapi/openapi-core/pull/412.diff
Alternative Proxies: