Skip to content

Commit

Permalink
Improve CEL cost tests to catch unhandled estimates or types
Browse files Browse the repository at this point in the history
  • Loading branch information
liggitt committed Jul 19, 2024
1 parent 92e3445 commit 1d2ad28
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
41 changes: 39 additions & 2 deletions staging/src/k8s.io/apiserver/pkg/cel/library/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package library

import (
"fmt"
"math"

"github.com/google/cel-go/checker"
Expand All @@ -25,9 +26,28 @@ import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"

"k8s.io/apiserver/pkg/cel"
)

// panicOnUnknown makes cost estimate functions panic on unrecognized functions.
// This is only set to true for unit tests.
var panicOnUnknown = false

// builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator.
var knownUnhandledFunctions = map[string]bool{
"uint": true,
"duration": true,
"bytes": true,
"timestamp": true,
"value": true,
"_==_": true,
"_&&_": true,
"_>_": true,
"!_": true,
"strings.quote": true,
}

// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
type CostEstimator struct {
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
Expand Down Expand Up @@ -106,7 +126,7 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
cost := uint64(1)
return &cost
Expand Down Expand Up @@ -185,6 +205,13 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
cost := uint64(1)
return &cost
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
cost := uint64(1)
return &cost
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
}
return nil
}
Expand Down Expand Up @@ -359,7 +386,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// So we double the cost of parsing the string.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "containsIP":
Expand Down Expand Up @@ -414,6 +441,12 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
}
return nil
}
Expand All @@ -422,6 +455,10 @@ func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
if panicOnUnknown {
// debug.PrintStack()
panic(fmt.Errorf("actualSize: non-sizer type %T", value))
}
return 1
}

Expand Down
9 changes: 9 additions & 0 deletions staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,10 @@ func TestSetsCost(t *testing.T) {
}

func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
originalPanicOnUnknown := panicOnUnknown
panicOnUnknown = true
t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })

est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
env, err := cel.NewEnv(
ext.Strings(ext.StringsVersion(2)),
Expand Down Expand Up @@ -1168,6 +1172,11 @@ func TestSize(t *testing.T) {
expectSize: checker.SizeEstimate{Min: 2, Max: 4},
},
}

originalPanicOnUnknown := panicOnUnknown
panicOnUnknown = true
t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })

est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
Expand Down

0 comments on commit 1d2ad28

Please sign in to comment.