Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Python codegen for the OpenAPI backend #21316

Merged
merged 21 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add test_openapi_generator_name_validation
  • Loading branch information
grihabor committed Dec 21, 2024
commit e25726d94edc21464b1e05a628be1cc8eaf63e80
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pants.core.goals.test import rules as test_rules
from pants.core.util_rules.config_files import rules as config_files_rules
from pants.engine.addresses import Address, Addresses
from pants.engine.internals.scheduler import ExecutionError
from pants.engine.target import (
Dependencies,
DependenciesRequest,
Expand Down Expand Up @@ -63,14 +64,13 @@ def rule_runner() -> RuleRunner:
return rule_runner


def _assert_generated_files(
def _get_generated_files(
rule_runner: RuleRunner,
address: Address,
*,
expected_files: Iterable[str],
source_roots: Iterable[str] | None = None,
extra_args: Iterable[str] = (),
) -> None:
) -> tuple[str, ...]:
args = []
if source_roots:
args.append(f"--source-root-patterns={repr(source_roots)}")
Expand All @@ -85,9 +85,7 @@ def _assert_generated_files(
GeneratedSources,
[GeneratePythonFromOpenAPIRequest(protocol_sources.snapshot, tgt)],
)

# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
assert set(generated_sources.snapshot.files) == set(expected_files)
return generated_sources.snapshot.files


def test_skip_generate_python(rule_runner: RuleRunner) -> None:
Expand All @@ -99,7 +97,9 @@ def test_skip_generate_python(rule_runner: RuleRunner) -> None:
)

def assert_gen(address: Address, expected: Iterable[str]) -> None:
_assert_generated_files(rule_runner, address, expected_files=expected)
generated_files = _get_generated_files(rule_runner, address)
# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
assert set(generated_files) == set(expected)

tgt_address = Address("", target_name="petstore")
assert_gen(tgt_address, [])
Expand Down Expand Up @@ -152,9 +152,9 @@ def test_generate_python_sources(rule_runner: RuleRunner, requirements_text: str
)

def assert_gen(address: Address, expected: Iterable[str]) -> None:
_assert_generated_files(
rule_runner, address, source_roots=["src/openapi"], expected_files=expected
)
generated_files = _get_generated_files(rule_runner, address, source_roots=["src/openapi"])
# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
assert set(generated_files) == set(expected)

tgt_address = Address("src/openapi", target_name="petstore")
assert_gen(
Expand Down Expand Up @@ -259,9 +259,9 @@ def test_generate_python_sources_with_a_different_generator(
)

def assert_gen(address: Address, expected: Iterable[str]) -> None:
_assert_generated_files(
rule_runner, address, source_roots=["src/openapi"], expected_files=expected
)
generated_files = _get_generated_files(rule_runner, address, source_roots=["src/openapi"])
# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
assert set(generated_files) == set(expected)

tgt_address = Address("src/openapi", target_name="petstore")
assert_gen(
Expand Down Expand Up @@ -289,6 +289,34 @@ def assert_gen(address: Address, expected: Iterable[str]) -> None:
assert runtime_dependencies


def test_openapi_generator_name_validation(rule_runner: RuleRunner, requirements_text: str):
rule_runner.write_files(
{
"3rdparty/python/default.lock": resources.files(__package__)
.joinpath("openapi.test.lock")
.read_text(),
"3rdparty/python/BUILD": requirements_text,
"src/openapi/BUILD": dedent(
"""\
openapi_document(
name="petstore",
source="petstore_spec.yaml",
python_generator_name="python-xxx",
)
"""
),
"src/openapi/petstore_spec.yaml": PETSTORE_SAMPLE_SPEC,
}
)

address = Address("src/openapi", target_name="petstore")
with pytest.raises(
ExecutionError,
match="ValueError: OpenAPI generator `python-xxx` is not found, available generators: ",
):
_get_generated_files(rule_runner, address, source_roots=["src/openapi"])


def test_generate_python_sources_using_custom_package_name(
rule_runner: RuleRunner,
requirements_text: str,
Expand Down Expand Up @@ -316,9 +344,9 @@ def test_generate_python_sources_using_custom_package_name(
)

def assert_gen(address: Address, expected: Iterable[str]) -> None:
_assert_generated_files(
rule_runner, address, source_roots=["src/openapi"], expected_files=expected
)
generated_files = _get_generated_files(rule_runner, address, source_roots=["src/openapi"])
# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
assert set(generated_files) == set(expected)

assert_gen(
Address("src/openapi", target_name="petstore"),
Expand Down
25 changes: 21 additions & 4 deletions src/python/pants/backend/openapi/util_rules/generator_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from __future__ import annotations

import dataclasses
import re
from dataclasses import dataclass
from typing import Iterable, Mapping
from typing import Iterable, Iterator, Mapping

from pants.backend.openapi.subsystems import openapi_generator
from pants.backend.openapi.subsystems.openapi_generator import OpenAPIGenerator
Expand Down Expand Up @@ -78,6 +79,13 @@ class OpenAPIGeneratorNames:
names: tuple[str, ...]


def _parse_names(stdout: str) -> Iterator[str]:
regex = re.compile(r"^ *- (?P<name>[^ ]+)")
for line in stdout.splitlines():
if (match := regex.match(line)) is not None:
yield match.group("name")


@rule
async def get_openapi_generator_names(
subsystem: OpenAPIGenerator, jdk: InternalJdk
Expand All @@ -97,7 +105,7 @@ async def get_openapi_generator_names(

jvm_process = JvmProcess(
jdk=jdk,
argv=[_GENERATOR_CLASS_NAME, "list", "-a"],
argv=[_GENERATOR_CLASS_NAME, "list"],
classpath_entries=classpath_entries,
input_digest=EMPTY_DIGEST,
extra_immutable_input_digests=immutable_input_digests,
Expand All @@ -106,13 +114,22 @@ async def get_openapi_generator_names(
cache_scope=ProcessCacheScope.SUCCESSFUL,
)
result = await Get(ProcessResult, JvmProcess, jvm_process)
return OpenAPIGeneratorNames(names=tuple(result.stdout.decode("utf-8").split()))
return OpenAPIGeneratorNames(names=tuple(_parse_names(result.stdout.decode("utf-8"))))


@rule
async def openapi_generator_process(
request: OpenAPIGeneratorProcess, jdk: InternalJdk, subsystem: OpenAPIGenerator
request: OpenAPIGeneratorProcess,
jdk: InternalJdk,
subsystem: OpenAPIGenerator,
generator_names: OpenAPIGeneratorNames,
) -> Process:
if request.generator_name not in generator_names.names:
names = ", ".join(f"`{name}`" for name in generator_names.names)
raise ValueError(
f"OpenAPI generator `{request.generator_name}` is not found, available generators: {names}"
)

tool_classpath = await Get(
ToolClasspath, ToolClasspathRequest(lockfile=GenerateJvmLockfileFromTool.create(subsystem))
)
Expand Down
Loading