Skip to content

Commit 45bbe0a

Browse files
authored
Merge pull request #291 from python-openapi/feature/versions-submodule
Versions submodule
2 parents 74a6c5f + db016fa commit 45bbe0a

File tree

9 files changed

+152
-31
lines changed

9 files changed

+152
-31
lines changed

openapi_spec_validator/shortcuts.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,29 @@
1010
from openapi_spec_validator.validation import OpenAPIV2SpecValidator
1111
from openapi_spec_validator.validation import OpenAPIV30SpecValidator
1212
from openapi_spec_validator.validation import OpenAPIV31SpecValidator
13-
from openapi_spec_validator.validation.finders import SpecFinder
14-
from openapi_spec_validator.validation.finders import SpecVersion
13+
from openapi_spec_validator.validation.exceptions import ValidatorDetectError
1514
from openapi_spec_validator.validation.protocols import SupportsValidation
1615
from openapi_spec_validator.validation.types import SpecValidatorType
1716
from openapi_spec_validator.validation.validators import SpecValidator
18-
19-
SPECS: Mapping[SpecVersion, SpecValidatorType] = {
20-
SpecVersion("swagger", "2.0"): OpenAPIV2SpecValidator,
21-
SpecVersion("openapi", "3.0"): OpenAPIV30SpecValidator,
22-
SpecVersion("openapi", "3.1"): OpenAPIV31SpecValidator,
17+
from openapi_spec_validator.versions import consts as versions
18+
from openapi_spec_validator.versions.datatypes import SpecVersion
19+
from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound
20+
from openapi_spec_validator.versions.shortcuts import get_spec_version
21+
22+
SPEC2VALIDATOR: Mapping[SpecVersion, SpecValidatorType] = {
23+
versions.OPENAPIV2: OpenAPIV2SpecValidator,
24+
versions.OPENAPIV30: OpenAPIV30SpecValidator,
25+
versions.OPENAPIV31: OpenAPIV31SpecValidator,
2326
}
2427

2528

2629
def get_validator_cls(spec: Schema) -> SpecValidatorType:
27-
return SpecFinder(SPECS).find(spec)
30+
try:
31+
spec_version = get_spec_version(spec)
32+
# backward compatibility
33+
except OpenAPIVersionNotFound:
34+
raise ValidatorDetectError
35+
return SPEC2VALIDATOR[spec_version]
2836

2937

3038
def validate_spec(

openapi_spec_validator/validation/finders.py

-23
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from openapi_spec_validator.versions.consts import OPENAPIV2
2+
from openapi_spec_validator.versions.consts import OPENAPIV30
3+
from openapi_spec_validator.versions.consts import OPENAPIV31
4+
from openapi_spec_validator.versions.datatypes import SpecVersion
5+
from openapi_spec_validator.versions.shortcuts import get_spec_version
6+
7+
__all__ = [
8+
"OPENAPIV2",
9+
"OPENAPIV30",
10+
"OPENAPIV31",
11+
"SpecVersion",
12+
"get_spec_version",
13+
]
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import List
2+
3+
from openapi_spec_validator.versions.datatypes import SpecVersion
4+
5+
OPENAPIV2 = SpecVersion(
6+
keyword="swagger",
7+
major="2",
8+
minor="0",
9+
)
10+
11+
OPENAPIV30 = SpecVersion(
12+
keyword="openapi",
13+
major="3",
14+
minor="0",
15+
)
16+
17+
OPENAPIV31 = SpecVersion(
18+
keyword="openapi",
19+
major="3",
20+
minor="1",
21+
)
22+
23+
VERSIONS: List[SpecVersion] = [OPENAPIV2, OPENAPIV30, OPENAPIV31]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(frozen=True)
5+
class SpecVersion:
6+
"""
7+
Spec version designates the OAS feature set.
8+
"""
9+
10+
keyword: str
11+
major: str
12+
minor: str
13+
14+
def __str__(self) -> str:
15+
return f"OpenAPIV{self.major}.{self.minor}"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from openapi_spec_validator.exceptions import OpenAPIError
2+
3+
4+
class OpenAPIVersionNotFound(OpenAPIError):
5+
def __str__(self) -> str:
6+
return "Specification version not found"
+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from re import compile
2+
from typing import List
3+
4+
from jsonschema_spec.typing import Schema
5+
6+
from openapi_spec_validator.versions.datatypes import SpecVersion
7+
from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound
8+
9+
10+
class SpecVersionFinder:
11+
pattern = compile(r"(?P<major>\d+)\.(?P<minor>\d+)(\..*)?")
12+
13+
def __init__(self, versions: List[SpecVersion]) -> None:
14+
self.versions = versions
15+
16+
def find(self, spec: Schema) -> SpecVersion:
17+
for v in self.versions:
18+
if v.keyword in spec:
19+
version_str = spec[v.keyword]
20+
m = self.pattern.match(version_str)
21+
if m:
22+
version = SpecVersion(**m.groupdict(), keyword=v.keyword)
23+
if v == version:
24+
return v
25+
26+
raise OpenAPIVersionNotFound
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from jsonschema_spec.typing import Schema
2+
3+
from openapi_spec_validator.versions.consts import VERSIONS
4+
from openapi_spec_validator.versions.datatypes import SpecVersion
5+
from openapi_spec_validator.versions.finders import SpecVersionFinder
6+
7+
8+
def get_spec_version(spec: Schema) -> SpecVersion:
9+
finder = SpecVersionFinder(VERSIONS)
10+
return finder.find(spec)

tests/integration/test_versions.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
3+
from openapi_spec_validator.versions import consts as versions
4+
from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound
5+
from openapi_spec_validator.versions.shortcuts import get_spec_version
6+
7+
8+
class TestGetSpecVersion:
9+
def test_no_keyword(self):
10+
spec = {}
11+
12+
with pytest.raises(OpenAPIVersionNotFound):
13+
get_spec_version(spec)
14+
15+
@pytest.mark.parametrize("keyword", ["swagger", "openapi"])
16+
@pytest.mark.parametrize("version", ["x.y.z", "xyz2.0.0", "2.xyz0.0"])
17+
def test_invalid(self, keyword, version):
18+
spec = {
19+
keyword: version,
20+
}
21+
22+
with pytest.raises(OpenAPIVersionNotFound):
23+
get_spec_version(spec)
24+
25+
@pytest.mark.parametrize(
26+
"keyword,version,expected",
27+
[
28+
("swagger", "2.0", versions.OPENAPIV2),
29+
("openapi", "3.0.0", versions.OPENAPIV30),
30+
("openapi", "3.0.1", versions.OPENAPIV30),
31+
("openapi", "3.0.2", versions.OPENAPIV30),
32+
("openapi", "3.0.3", versions.OPENAPIV30),
33+
("openapi", "3.1.0", versions.OPENAPIV31),
34+
],
35+
)
36+
def test_valid(self, keyword, version, expected):
37+
spec = {
38+
keyword: version,
39+
}
40+
41+
result = get_spec_version(spec)
42+
43+
assert result == expected

0 commit comments

Comments
 (0)