Skip to content

Commit

Permalink
advancedtls: Rename VType (#7149)
Browse files Browse the repository at this point in the history
* renamed VType to VerificationType and add deprecation note
  • Loading branch information
gtcooke94 authored Apr 19, 2024
1 parent 09e6fdd commit 5fe2e74
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 232 deletions.
42 changes: 33 additions & 9 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,13 @@ type ClientOptions struct {
// RootOptions is OPTIONAL on client side. If not set, we will try to use the
// default trust certificates in users' OS system.
RootOptions RootCertificateOptions
// VerificationType defines what type of server verification is done. See
// the `VerificationType` enum for the different options.
// Default: CertAndHostVerification
VerificationType VerificationType
// VType is the verification type on the client side.
//
// Deprecated: use VerificationType instead.
VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
Expand Down Expand Up @@ -210,7 +216,13 @@ type ServerOptions struct {
RootOptions RootCertificateOptions
// If the server want the client to send certificates.
RequireClientCert bool
// VerificationType defines what type of client verification is done. See
// the `VerificationType` enum for the different options.
// Default: CertAndHostVerification
VerificationType VerificationType
// VType is the verification type on the server side.
//
// Deprecated: use VerificationType instead.
VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
Expand All @@ -227,7 +239,13 @@ type ServerOptions struct {
}

func (o *ClientOptions) config() (*tls.Config, error) {
if o.VType == SkipVerification && o.VerifyPeer == nil {
// TODO(gtcooke94). VType is deprecated, eventually remove this block. This
// will ensure that users still explicitly setting `VType` will get the
// setting to the right place.
if o.VType != CertAndHostVerification {
o.VerificationType = o.VType
}
if o.VerificationType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
}
// Make sure users didn't specify more than one fields in
Expand Down Expand Up @@ -271,7 +289,7 @@ func (o *ClientOptions) config() (*tls.Config, error) {
default:
// No root certificate options specified by user. Use the certificates
// stored in system default path as the last resort.
if o.VType != SkipVerification {
if o.VerificationType != SkipVerification {
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
Expand Down Expand Up @@ -303,7 +321,13 @@ func (o *ClientOptions) config() (*tls.Config, error) {
}

func (o *ServerOptions) config() (*tls.Config, error) {
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
// TODO(gtcooke94). VType is deprecated, eventually remove this block. This
// will ensure that users still explicitly setting `VType` will get the
// setting to the right place.
if o.VType != CertAndHostVerification {
o.VerificationType = o.VType
}
if o.RequireClientCert && o.VerificationType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
// Make sure users didn't specify more than one fields in
Expand Down Expand Up @@ -351,7 +375,7 @@ func (o *ServerOptions) config() (*tls.Config, error) {
default:
// No root certificate options specified by user. Use the certificates
// stored in system default path as the last resort.
if o.VType != SkipVerification && o.RequireClientCert {
if o.VerificationType != SkipVerification && o.RequireClientCert {
systemRootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, err
Expand Down Expand Up @@ -395,7 +419,7 @@ type advancedTLSCreds struct {
verifyFunc CustomVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
vType VerificationType
verificationType VerificationType
revocationConfig *RevocationConfig
}

Expand Down Expand Up @@ -495,7 +519,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
}
rawCertList[i] = cert
}
if c.vType == CertAndHostVerification || c.vType == CertVerification {
if c.verificationType == CertAndHostVerification || c.verificationType == CertVerification {
// perform possible trust credential reloading and certificate check
rootCAs := c.config.RootCAs
if !c.isClient {
Expand Down Expand Up @@ -527,7 +551,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
opts.Intermediates.AddCert(cert)
}
// Perform default hostname check if specified.
if c.isClient && c.vType == CertAndHostVerification && serverName != "" {
if c.isClient && c.verificationType == CertAndHostVerification && serverName != "" {
parsedName, _, err := net.SplitHostPort(serverName)
if err != nil {
// If the serverName had no host port or if the serverName cannot be
Expand Down Expand Up @@ -579,7 +603,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
verificationType: o.VerificationType,
revocationConfig: o.RevocationConfig,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
Expand All @@ -598,7 +622,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
verificationType: o.VerificationType,
revocationConfig: o.RevocationConfig,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
Expand Down
104 changes: 52 additions & 52 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,19 @@ func (s) TestEnd2End(t *testing.T) {
}
stage := &stageInfo{}
for _, test := range []struct {
desc string
clientCert []tls.Certificate
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientVType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVType VerificationType
desc string
clientCert []tls.Certificate
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientVerificationType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
serverVerificationType VerificationType
}{
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
Expand Down Expand Up @@ -178,8 +178,8 @@ func (s) TestEnd2End(t *testing.T) {
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
switch stage.read() {
case 0, 1:
Expand All @@ -191,7 +191,7 @@ func (s) TestEnd2End(t *testing.T) {
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
serverVerificationType: CertVerification,
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
Expand Down Expand Up @@ -219,7 +219,7 @@ func (s) TestEnd2End(t *testing.T) {
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
clientVerificationType: CertVerification,
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
Expand All @@ -232,7 +232,7 @@ func (s) TestEnd2End(t *testing.T) {
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
serverVerificationType: CertVerification,
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
Expand Down Expand Up @@ -284,7 +284,7 @@ func (s) TestEnd2End(t *testing.T) {
}
return nil, fmt.Errorf("custom authz check fails")
},
clientVType: CertVerification,
clientVerificationType: CertVerification,
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
Expand All @@ -297,7 +297,7 @@ func (s) TestEnd2End(t *testing.T) {
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
serverVType: CertVerification,
serverVerificationType: CertVerification,
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
Expand All @@ -317,9 +317,9 @@ func (s) TestEnd2End(t *testing.T) {
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverRoot: cs.ServerTrust1,
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
switch stage.read() {
case 0, 2:
Expand All @@ -330,7 +330,7 @@ func (s) TestEnd2End(t *testing.T) {
return nil, fmt.Errorf("custom authz check fails")
}
},
serverVType: CertVerification,
serverVerificationType: CertVerification,
},
} {
test := test
Expand All @@ -347,7 +347,7 @@ func (s) TestEnd2End(t *testing.T) {
},
RequireClientCert: true,
VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType,
VerificationType: test.serverVerificationType,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
if err != nil {
Expand All @@ -373,7 +373,7 @@ func (s) TestEnd2End(t *testing.T) {
RootCACerts: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
},
VType: test.clientVType,
VerificationType: test.clientVerificationType,
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
Expand Down Expand Up @@ -638,7 +638,7 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) {
VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
VType: CertVerification,
VerificationType: CertVerification,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
if err != nil {
Expand All @@ -664,7 +664,7 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) {
RootOptions: RootCertificateOptions{
RootProvider: clientRootProvider,
},
VType: CertVerification,
VerificationType: CertVerification,
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
Expand Down Expand Up @@ -731,34 +731,34 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
for _, test := range []struct {
desc string
clientRoot *x509.CertPool
clientVType VerificationType
serverCert []tls.Certificate
serverVType VerificationType
expectError bool
desc string
clientRoot *x509.CertPool
clientVerificationType VerificationType
serverCert []tls.Certificate
serverVerificationType VerificationType
expectError bool
}{
// Client side sets vType to CertAndHostVerification, and will do
// default hostname check. Server uses a cert without "localhost" or
// "127.0.0.1" as common name or SAN names, and will hence fail.
{
desc: "Bad default hostname check",
clientRoot: cs.ClientTrust1,
clientVType: CertAndHostVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertAndHostVerification,
expectError: true,
desc: "Bad default hostname check",
clientRoot: cs.ClientTrust1,
clientVerificationType: CertAndHostVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverVerificationType: CertAndHostVerification,
expectError: true,
},
// Client side sets vType to CertAndHostVerification, and will do
// default hostname check. Server uses a certificate with "localhost" as
// common name, and will hence pass the default hostname check.
{
desc: "Good default hostname check",
clientRoot: cs.ClientTrust1,
clientVType: CertAndHostVerification,
serverCert: []tls.Certificate{cs.ServerPeerLocalhost1},
serverVType: CertAndHostVerification,
expectError: false,
desc: "Good default hostname check",
clientRoot: cs.ClientTrust1,
clientVerificationType: CertAndHostVerification,
serverCert: []tls.Certificate{cs.ServerPeerLocalhost1},
serverVerificationType: CertAndHostVerification,
expectError: false,
},
} {
test := test
Expand All @@ -769,7 +769,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
Certificates: test.serverCert,
},
RequireClientCert: false,
VType: test.serverVType,
VerificationType: test.serverVerificationType,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
if err != nil {
Expand All @@ -789,7 +789,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
},
VType: test.clientVType,
VerificationType: test.clientVerificationType,
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
Expand Down Expand Up @@ -907,7 +907,7 @@ func (s) TestTLSVersions(t *testing.T) {
Certificates: []tls.Certificate{cs.ServerPeerLocalhost1},
},
RequireClientCert: false,
VType: CertAndHostVerification,
VerificationType: CertAndHostVerification,
MinVersion: test.serverMinVersion,
MaxVersion: test.serverMaxVersion,
}
Expand All @@ -929,9 +929,9 @@ func (s) TestTLSVersions(t *testing.T) {
RootOptions: RootCertificateOptions{
RootCACerts: cs.ClientTrust1,
},
VType: CertAndHostVerification,
MinVersion: test.clientMinVersion,
MaxVersion: test.clientMaxVersion,
VerificationType: CertAndHostVerification,
MinVersion: test.clientMinVersion,
MaxVersion: test.clientMaxVersion,
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
Expand Down
Loading

0 comments on commit 5fe2e74

Please sign in to comment.