Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate a single swagger.json file for all frameworks #1437

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions hack/python-sdk/gen-sdk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ SWAGGER_JAR_URL="https://repo1.maven.org/maven2/org/openapitools/openapi-generat
SWAGGER_CODEGEN_JAR="${repo_root}/hack/python-sdk/openapi-generator-cli.jar"
SWAGGER_CODEGEN_CONF="${repo_root}/hack/python-sdk/swagger_config.json"
SDK_OUTPUT_PATH="${repo_root}/sdk/python"
FRAMEWORKS=(tensorflow pytorch mxnet xgboost)
VERSION=1.3.0
SWAGGER_CODEGEN_FILE="${repo_root}/hack/python-sdk/swagger.json"

if [ -z "${GOPATH:-}" ]; then
export GOPATH=$(go env GOPATH)
Expand All @@ -39,32 +39,15 @@ if [[ ! -f "$SWAGGER_CODEGEN_JAR" ]]; then
wget -O "${SWAGGER_CODEGEN_JAR}" ${SWAGGER_JAR_URL}
fi


for FRAMEWORK in ${FRAMEWORKS[@]}; do
SWAGGER_CODEGEN_FILE="pkg/apis/${FRAMEWORK}/v1/swagger.json"
echo "Generating swagger file for ${FRAMEWORK} ..."
go run "${repo_root}"/hack/python-sdk/main.go "${FRAMEWORK}" ${VERSION} > "${SWAGGER_CODEGEN_FILE}"
done

echo "Merging swagger files from different frameworks into one"
download_url=$(curl -s https://api.github.com/repos/go-swagger/go-swagger/releases/latest | \
jq -r '.assets[] | select(.name | contains("'"$(uname | tr '[:upper:]' '[:lower:]')"'_amd64")) | .browser_download_url')
curl -o /tmp/swagger -L'#' "$download_url"
chmod +x /tmp/swagger

# it will report warning like 'v1.SchedulingPolicy' already exists in primary or higher priority mixin, skipping
# error code is not 0 but t's acceptable.
/tmp/swagger mixin "${repo_root}"/pkg/apis/tensorflow/v1/swagger.json "${repo_root}"/pkg/apis/pytorch/v1/swagger.json \
"${repo_root}"/pkg/apis/mxnet/v1/swagger.json "${repo_root}"/pkg/apis/xgboost/v1/swagger.json \
--output "${repo_root}"/hack/python-sdk/swagger.json --quiet || true
echo "Generating swagger file ..."
go run "${repo_root}"/hack/python-sdk/main.go ${VERSION} > "${SWAGGER_CODEGEN_FILE}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You re-run this script to generate the new Swagger, right ?

Copy link
Member Author

@alembiewski alembiewski Oct 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's correct. It has produced the file with the same contents as before, except now it has a proper title and description, so, effectively, no changes. And I also regenerated the SDK using the updated spec


echo "Removing previously generated files ..."
rm -rf "${SDK_OUTPUT_PATH}"/docs/V1*.md "${SDK_OUTPUT_PATH}"/kubeflow/training/models "${SDK_OUTPUT_PATH}"/kubeflow/training/*.py "${SDK_OUTPUT_PATH}"/test/*.py
echo "Generating Python SDK for Training Operator ..."
java -jar "${SWAGGER_CODEGEN_JAR}" generate -i "${repo_root}"/hack/python-sdk/swagger.json -g python -o "${SDK_OUTPUT_PATH}" -c "${SWAGGER_CODEGEN_CONF}"

echo "Kubeflow Training Operator Python SDK is generated successfully to folder ${SDK_OUTPUT_PATH}/."
rm /tmp/swagger

echo "Running post-generation script ..."
"${repo_root}"/hack/python-sdk/post_gen.py
71 changes: 37 additions & 34 deletions hack/python-sdk/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 kubeflow.org.
Copyright 2021 kubeflow.org.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -19,52 +19,52 @@ package main
import (
"encoding/json"
"fmt"
mxnet "github.com/kubeflow/tf-operator/pkg/apis/mxnet/v1"
pytorch "github.com/kubeflow/tf-operator/pkg/apis/pytorch/v1"
tensorflow "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
xgboost "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
"os"
"strings"

"github.com/go-openapi/spec"
mxJob "github.com/kubeflow/tf-operator/pkg/apis/mxnet/v1"
pytorchJob "github.com/kubeflow/tf-operator/pkg/apis/pytorch/v1"
tfjob "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
xgboostJob "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
"k8s.io/klog"
"k8s.io/kube-openapi/pkg/common"
)

// Generate OpenAPI spec definitions for TFJob Resource
// Generate OpenAPI spec definitions for API resources
func main() {
if len(os.Args) <= 2 {
klog.Fatal("Supply a framework and version")
if len(os.Args) <= 1 {
klog.Fatal("Supply a version")
}
framework := os.Args[1]
version := os.Args[2]
version := os.Args[1]
if !strings.HasPrefix(version, "v") {
version = "v" + version
}
var oAPIDefs map[string]common.OpenAPIDefinition
var oAPIDefs = map[string]common.OpenAPIDefinition{}
defs := spec.Definitions{}

switch framework {
case "tensorflow":
oAPIDefs = tfjob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "pytorch":
oAPIDefs = pytorchJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "mxnet":
oAPIDefs = mxJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "xgboost":
oAPIDefs = xgboostJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
refCallback := func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name)))
}

for k, v := range tensorflow.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range pytorch.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range mxnet.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range xgboost.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

defs := spec.Definitions{}
for defName, val := range oAPIDefs {
defs[swaggify(defName, framework)] = val.Schema
defs[swaggify(defName)] = val.Schema
}
swagger := spec.Swagger{
SwaggerProps: spec.SwaggerProps{
Expand All @@ -73,8 +73,8 @@ func main() {
Paths: &spec.Paths{Paths: map[string]spec.PathItem{}},
Info: &spec.Info{
InfoProps: spec.InfoProps{
Title: framework,
Description: fmt.Sprintf("Python SDK for %v", framework),
Title: "Kubeflow Training SDK",
Description: "Python SDK for Kubeflow Training",
Version: version,
},
},
Expand All @@ -87,8 +87,11 @@ func main() {
fmt.Println(string(jsonBytes))
}

func swaggify(name, framework string) string {
name = strings.Replace(name, fmt.Sprintf("github.com/kubeflow/tf-operator/pkg/apis/%s/", framework), "", -1)
func swaggify(name string) string {
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/pytorch/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/mxnet/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/xgboost/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/common/pkg/apis/common/", "", -1)
name = strings.Replace(name, "k8s.io/api/core/", "", -1)
name = strings.Replace(name, "k8s.io/apimachinery/pkg/apis/meta/", "", -1)
Expand Down
Loading