Skip to content

Commit

Permalink
Begin making the patch handler maintainable
Browse files Browse the repository at this point in the history
  • Loading branch information
lavalamp committed Apr 18, 2018
1 parent c8cded5 commit 953955b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 30 deletions.
74 changes: 46 additions & 28 deletions staging/src/k8s.io/apiserver/pkg/endpoints/handlers/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,18 @@ func PatchResource(r rest.Patcher, scope RequestScope, admit admission.Interface
return nil
}

result, err := patchResource(
p := patcher{
codec: codec,
namer: scope.Namer,
creater: scope.Creater,
defaulter: scope.Defaulter,
unsafeConvertor: scope.Convertor,
kind: scope.Kind,
resource: scope.Resource,
trace: trace,
}

result, err := p.patchResource(
ctx,
updateMutation,
rest.AdmissionToValidateObjectFunc(admit, staticAdmissionAttributes),
Expand All @@ -125,7 +136,7 @@ func PatchResource(r rest.Patcher, scope RequestScope, admit admission.Interface
name,
patchType,
patchJS,
scope.Namer, scope.Creater, scope.Defaulter, scope.UnsafeConvertor, scope.Kind, scope.Resource, codec, trace)
)
if err != nil {
scope.err(err, w, req)
return
Expand All @@ -149,8 +160,23 @@ func PatchResource(r rest.Patcher, scope RequestScope, admit admission.Interface

type mutateObjectUpdateFunc func(obj, old runtime.Object) error

type patcher struct {

// Pieces of RequestScope
namer ScopeNamer
creater runtime.ObjectCreater
defaulter runtime.ObjectDefaulter
unsafeConvertor runtime.ObjectConvertor
resource schema.GroupVersionResource
kind schema.GroupVersionKind

codec runtime.Codec

trace *utiltrace.Trace
}

// patchResource divides PatchResource for easier unit testing
func patchResource(
func (p *patcher) patchResource(
ctx request.Context,
updateMutation mutateObjectUpdateFunc,
createValidation rest.ValidateObjectFunc,
Expand All @@ -161,14 +187,6 @@ func patchResource(
name string,
patchType types.PatchType,
patchJS []byte,
namer ScopeNamer,
creater runtime.ObjectCreater,
defaulter runtime.ObjectDefaulter,
unsafeConvertor runtime.ObjectConvertor,
kind schema.GroupVersionKind,
resource schema.GroupVersionResource,
codec runtime.Codec,
trace *utiltrace.Trace,
) (runtime.Object, error) {

namespace := request.NamespaceValue(ctx)
Expand All @@ -186,11 +204,11 @@ func patchResource(
// and is given the currently persisted object as input.
applyPatch := func(_ request.Context, _, currentObject runtime.Object) (runtime.Object, error) {
// Make sure we actually have a persisted currentObject
trace.Step("About to apply patch")
p.trace.Step("About to apply patch")
if hasUID, err := hasUID(currentObject); err != nil {
return nil, err
} else if !hasUID {
return nil, errors.NewNotFound(resource.GroupResource(), name)
return nil, errors.NewNotFound(p.resource.GroupResource(), name)
}

currentResourceVersion := ""
Expand All @@ -213,7 +231,7 @@ func patchResource(
// representations.
switch patchType {
case types.JSONPatchType, types.MergePatchType:
originalJS, patchedJS, err := patchObjectJSON(patchType, codec, currentObject, patchJS, objToUpdate, versionedObj)
originalJS, patchedJS, err := patchObjectJSON(patchType, p.codec, currentObject, patchJS, objToUpdate, versionedObj)
if err != nil {
return nil, interpretPatchError(err)
}
Expand Down Expand Up @@ -241,11 +259,11 @@ func patchResource(
case types.StrategicMergePatchType:
// Since the patch is applied on versioned objects, we need to convert the
// current object to versioned representation first.
currentVersionedObject, err := unsafeConvertor.ConvertToVersion(currentObject, kind.GroupVersion())
currentVersionedObject, err := p.unsafeConvertor.ConvertToVersion(currentObject, p.kind.GroupVersion())
if err != nil {
return nil, err
}
versionedObjToUpdate, err := creater.New(kind)
versionedObjToUpdate, err := p.creater.New(p.kind)
if err != nil {
return nil, err
}
Expand All @@ -254,12 +272,12 @@ func patchResource(
if err != nil {
return nil, err
}
if err := strategicPatchObject(codec, defaulter, currentVersionedObject, patchJS, versionedObjToUpdate, versionedObj); err != nil {
if err := strategicPatchObject(p.codec, p.defaulter, currentVersionedObject, patchJS, versionedObjToUpdate, versionedObj); err != nil {
return nil, err
}
// Convert the object back to unversioned.
gvk := kind.GroupKind().WithVersion(runtime.APIVersionInternal)
unversionedObjToUpdate, err := unsafeConvertor.ConvertToVersion(versionedObjToUpdate, gvk.GroupVersion())
gvk := p.kind.GroupKind().WithVersion(runtime.APIVersionInternal)
unversionedObjToUpdate, err := p.unsafeConvertor.ConvertToVersion(versionedObjToUpdate, gvk.GroupVersion())
if err != nil {
return nil, err
}
Expand All @@ -276,7 +294,7 @@ func patchResource(
return patchMap, nil
}
}
if err := checkName(objToUpdate, name, namespace, namer); err != nil {
if err := checkName(objToUpdate, name, namespace, p.namer); err != nil {
return nil, err
}
return objToUpdate, nil
Expand All @@ -292,7 +310,7 @@ func patchResource(

// Since the patch is applied on versioned objects, we need to convert the
// current object to versioned representation first.
currentVersionedObject, err := unsafeConvertor.ConvertToVersion(currentObject, kind.GroupVersion())
currentVersionedObject, err := p.unsafeConvertor.ConvertToVersion(currentObject, p.kind.GroupVersion())
if err != nil {
return nil, err
}
Expand All @@ -310,7 +328,7 @@ func patchResource(
}
} else {
// Compute current patch.
currentObjJS, err := runtime.Encode(codec, currentObject)
currentObjJS, err := runtime.Encode(p.codec, currentObject)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -350,19 +368,19 @@ func patchResource(
return nil, lastConflictErr
}
// Otherwise manufacture one of our own
return nil, errors.NewConflict(resource.GroupResource(), name, patchDiffErr)
return nil, errors.NewConflict(p.resource.GroupResource(), name, patchDiffErr)
}

versionedObjToUpdate, err := creater.New(kind)
versionedObjToUpdate, err := p.creater.New(p.kind)
if err != nil {
return nil, err
}
if err := applyPatchToObject(codec, defaulter, currentObjMap, originalPatchMap, versionedObjToUpdate, versionedObj); err != nil {
if err := applyPatchToObject(p.codec, p.defaulter, currentObjMap, originalPatchMap, versionedObjToUpdate, versionedObj); err != nil {
return nil, err
}
// Convert the object back to unversioned.
gvk := kind.GroupKind().WithVersion(runtime.APIVersionInternal)
objToUpdate, err := unsafeConvertor.ConvertToVersion(versionedObjToUpdate, gvk.GroupVersion())
gvk := p.kind.GroupKind().WithVersion(runtime.APIVersionInternal)
objToUpdate, err := p.unsafeConvertor.ConvertToVersion(versionedObjToUpdate, gvk.GroupVersion())
if err != nil {
return nil, err
}
Expand All @@ -374,7 +392,7 @@ func patchResource(
// applyAdmission is called every time GuaranteedUpdate asks for the updated object,
// and is given the currently persisted object and the patched object as input.
applyAdmission := func(ctx request.Context, patchedObject runtime.Object, currentObject runtime.Object) (runtime.Object, error) {
trace.Step("About to check admission control")
p.trace.Step("About to check admission control")
return patchedObject, updateMutation(patchedObject, currentObject)
}
updatedObjectInfo := rest.DefaultUpdatedObjectInfo(nil, applyPatch, applyAdmission)
Expand Down
15 changes: 13 additions & 2 deletions staging/src/k8s.io/apiserver/pkg/endpoints/handlers/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,18 @@ func (tc *patchTestCase) Run(t *testing.T) {

}

resultObj, err := patchResource(
p := patcher{
codec: codec,
namer: namer,
creater: creater,
defaulter: defaulter,
unsafeConvertor: convertor,
kind: kind,
resource: resource,
trace: utiltrace.New("Patch" + name),
}

resultObj, err := p.patchResource(
ctx,
admissionMutation,
rest.ValidateAllObjectFunc,
Expand All @@ -348,7 +359,7 @@ func (tc *patchTestCase) Run(t *testing.T) {
name,
patchType,
patch,
namer, creater, defaulter, convertor, kind, resource, codec, utiltrace.New("Patch"+name))
)
if len(tc.expectedError) != 0 {
if err == nil || err.Error() != tc.expectedError {
t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err)
Expand Down

0 comments on commit 953955b

Please sign in to comment.