Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Python codegen for the OpenAPI backend #21316

Merged
merged 21 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add a test for python-fastapi generator
  • Loading branch information
grihabor committed Dec 21, 2024
commit 47228fe397e91786bebe596890a3041c61a8e665
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from __future__ import annotations

from pants.backend.experimental.python.register import rules as python_rules
from pants.backend.experimental.python.register import target_types as python_target_types
from pants.backend.experimental.openapi.register import rules as openapi_rules
from pants.backend.experimental.openapi.register import target_types as openapi_target_types
from pants.backend.experimental.python.register import rules as python_rules
from pants.backend.experimental.python.register import target_types as python_target_types
from pants.backend.openapi.codegen.python.rules import rules as openapi_python_codegen_rules


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import itertools

from pants.backend.openapi.target_types import OpenApiDocumentGeneratorTarget, OpenApiDocumentTarget
from pants.backend.python.target_types import PrefixedPythonResolveField
from pants.engine.target import BoolField, DictStringToStringField, StringField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ async def infer_openapi_python_dependencies(
addresses, missing_requirements = [], []
for runtime_dependency in compiled_sources.runtime_dependencies:
project_name = runtime_dependency.project_name
address = requirements_to_addresses.get(project_name)
address = requirements_to_addresses.get(project_name.lower())
if address is not None:
addresses.append(address)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _assert_generated_files(
HydratedSources, [HydrateSourcesRequest(tgt[OpenApiDocumentField])]
)
generated_sources = rule_runner.request(
GeneratedSources, [GeneratePythonFromOpenAPIRequest(protocol_sources.snapshot, tgt)]
GeneratedSources,
[GeneratePythonFromOpenAPIRequest(protocol_sources.snapshot, tgt)],
)

# We only assert expected files are a subset of all generated since the generator creates a lot of support classes
Expand Down Expand Up @@ -191,6 +192,103 @@ def assert_gen(address: Address, expected: Iterable[str]) -> None:
assert runtime_dependencies


@pytest.fixture
def fastapi_requirements_text() -> str:
return dedent(
"""\
python_requirement(name="jinja2", requirements=["jinja2"])
python_requirement(name="markupsafe", requirements=["markupsafe"])
python_requirement(name="pyyaml", requirements=["pyyaml"])
python_requirement(name="rx", requirements=["rx"])
python_requirement(name="aiofiles", requirements=["aiofiles"])
python_requirement(name="aniso8601", requirements=["aniso8601"])
python_requirement(name="async-exit-stack", requirements=["async-exit-stack"])
python_requirement(name="async-generator", requirements=["async-generator"])
python_requirement(name="certifi", requirements=["certifi"])
python_requirement(name="chardet", requirements=["chardet"])
python_requirement(name="click", requirements=["click"])
python_requirement(name="dnspython", requirements=["dnspython"])
python_requirement(name="email-validator", requirements=["email-validator"])
python_requirement(name="fastapi", requirements=["fastapi"])
python_requirement(name="graphene", requirements=["graphene"])
python_requirement(name="graphql-core", requirements=["graphql-core"])
python_requirement(name="graphql-relay", requirements=["graphql-relay"])
python_requirement(name="h11", requirements=["h11"])
python_requirement(name="httptools", requirements=["httptools"])
python_requirement(name="idna", requirements=["idna"])
python_requirement(name="itsdangerous", requirements=["itsdangerous"])
python_requirement(name="orjson", requirements=["orjson"])
python_requirement(name="promise", requirements=["promise"])
python_requirement(name="pydantic", requirements=["pydantic"])
python_requirement(name="python-dotenv", requirements=["python-dotenv"])
python_requirement(name="python-multipart", requirements=["python-multipart"])
python_requirement(name="requests", requirements=["requests"])
python_requirement(name="six", requirements=["six"])
python_requirement(name="starlette", requirements=["starlette"])
python_requirement(name="typing-extensions", requirements=["typing-extensions"])
python_requirement(name="ujson", requirements=["ujson"])
python_requirement(name="urllib3", requirements=["urllib3"])
python_requirement(name="uvicorn", requirements=["uvicorn"])
python_requirement(name="uvloop", requirements=["uvloop"])
python_requirement(name="watchgod", requirements=["watchgod"])
python_requirement(name="websockets", requirements=["websockets"])
"""
)


def test_generate_python_sources_with_a_different_generator(
rule_runner: RuleRunner, fastapi_requirements_text: str
) -> None:
rule_runner.write_files(
{
"3rdparty/python/default.lock": resources.files(__package__)
.joinpath("openapi.test.lock")
.read_text(),
"3rdparty/python/BUILD": fastapi_requirements_text,
"src/openapi/BUILD": dedent(
"""\
openapi_document(
name="petstore",
source="petstore_spec.yaml",
python_generator_name="python-fastapi",
)
"""
),
"src/openapi/petstore_spec.yaml": PETSTORE_SAMPLE_SPEC,
}
)

def assert_gen(address: Address, expected: Iterable[str]) -> None:
_assert_generated_files(
rule_runner, address, source_roots=["src/openapi"], expected_files=expected
)

tgt_address = Address("src/openapi", target_name="petstore")
assert_gen(
tgt_address,
[
# The list might change because it depends on openapi template.
# TODO Vendor template?
"src/openapi/src/openapi_server/apis/__init__.py",
"src/openapi/src/openapi_server/apis/pets_api.py",
"src/openapi/src/openapi_server/main.py",
"src/openapi/src/openapi_server/models/__init__.py",
"src/openapi/src/openapi_server/models/error.py",
"src/openapi/src/openapi_server/models/extra_models.py",
"src/openapi/src/openapi_server/models/pet.py",
"src/openapi/src/openapi_server/security_api.py",
"src/openapi/tests/conftest.py",
"src/openapi/tests/test_pets_api.py",
],
)

tgt = rule_runner.get_target(tgt_address)
runtime_dependencies = rule_runner.request(
Addresses, [DependenciesRequest(tgt[OpenApiDocumentDependenciesField])]
)
assert runtime_dependencies


def test_generate_python_sources_using_custom_package_name(
rule_runner: RuleRunner,
requirements_text: str,
Expand Down
39 changes: 38 additions & 1 deletion src/python/pants/backend/openapi/util_rules/generator_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from pants.backend.openapi.subsystems import openapi_generator
from pants.backend.openapi.subsystems.openapi_generator import OpenAPIGenerator
from pants.engine.fs import Digest
from pants.engine.process import Process, ProcessCacheScope
from pants.engine.internals.native_engine import EMPTY_DIGEST
from pants.engine.process import Process, ProcessCacheScope, ProcessResult
from pants.engine.rules import Get, collect_rules, rule
from pants.jvm import jdk_rules, non_jvm_dependencies
from pants.jvm.jdk_rules import InternalJdk, JvmProcess
Expand Down Expand Up @@ -72,6 +73,42 @@ def __init__(
_GENERATOR_CLASS_NAME = "org.openapitools.codegen.OpenAPIGenerator"


@dataclass(frozen=True)
class OpenAPIGeneratorNames:
names: tuple[str, ...]


@rule
async def get_openapi_generator_names(
subsystem: OpenAPIGenerator, jdk: InternalJdk
) -> OpenAPIGeneratorNames:
tool_classpath = await Get(
ToolClasspath, ToolClasspathRequest(lockfile=GenerateJvmLockfileFromTool.create(subsystem))
)

toolcp_relpath = "__toolcp"
immutable_input_digests = {
toolcp_relpath: tool_classpath.digest,
}

classpath_entries = [
*tool_classpath.classpath_entries(toolcp_relpath),
]

jvm_process = JvmProcess(
jdk=jdk,
argv=[_GENERATOR_CLASS_NAME, "list", "-a"],
classpath_entries=classpath_entries,
input_digest=EMPTY_DIGEST,
extra_immutable_input_digests=immutable_input_digests,
extra_jvm_options=subsystem.jvm_options,
description="Get openapi generator names.",
cache_scope=ProcessCacheScope.SUCCESSFUL,
)
result = await Get(ProcessResult, JvmProcess, jvm_process)
return OpenAPIGeneratorNames(names=tuple(result.stdout.decode("utf-8").split()))


@rule
async def openapi_generator_process(
request: OpenAPIGeneratorProcess, jdk: InternalJdk, subsystem: OpenAPIGenerator
Expand Down
Loading