Skip to content

Commit

Permalink
using IMDSv2 instance metadata (#56)
Browse files Browse the repository at this point in the history
* using IMDSv2 instance metadata

* applying suggestions -- defer before reads / handle errors on NewRequest()

* Falling back to IMDSv1 if token API not available
  • Loading branch information
swbsf authored Sep 13, 2022
1 parent 4cbf72c commit 9d20ce5
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions pkg/provider/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"path"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -21,6 +22,11 @@ const (
// Retrieve instance metadata for AWS EC2 instance
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
instanceMetadataEndpoint = "http://169.254.169.254/latest/meta-data"

// IMDSv2 token related constants
tokenEndpoint = "http://169.254.169.254/latest/api/token"
tokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
tokenRequestHeader = "X-aws-ec2-metadata-token"
)

// The VPC identifier
Expand Down Expand Up @@ -60,40 +66,71 @@ func NewProvider() (provider.Provider, error) {
}, nil
}

func retrieveInstanceNetworkInterfaceMacAddress() (string, error) {
res, err := http.Get(instanceMetadataEndpoint + "/mac")
func getV2Token(client http.Client) (string, error) {
req, err := http.NewRequest(http.MethodPut, tokenEndpoint, nil)
if err != nil {
return "", err
}
req.Header.Set(tokenTTLHeader, "21600")
res, err := client.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)

token, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}

return string(token), nil
}

func retrieveInstanceMetadata(client http.Client, contextPath string, token string) (string, error) {
req, err := http.NewRequest(http.MethodGet, instanceMetadataEndpoint+contextPath, nil)
if err != nil {
return "", err
}

if token != "" {
req.Header.Set(tokenRequestHeader, token)
}
res, err := client.Do(req)
if err != nil {
return "", err
}

defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}
return string(body), nil
}

func retrieveVPCID() (string, error) {
if vpcID != "" {
return vpcID, nil
}
mac, err := retrieveInstanceNetworkInterfaceMacAddress()

client := http.Client{Timeout: 3 * time.Second}

token, err := getV2Token(client)
if err != nil {
return "", err
fmt.Printf("failed getting IMDSv2 token falling back to IMDSv1 : %s", err)
}

res, err := http.Get(fmt.Sprintf(instanceMetadataEndpoint + "/network/interfaces/macs/" + mac + "/vpc-id"))
mac, err := retrieveInstanceMetadata(client, "/mac", string(token))
if err != nil {
return "", err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)

body, err := retrieveInstanceMetadata(client, "/network/interfaces/macs/"+mac+"/vpc-id", string(token))
if err != nil {
return "", err
}

return string(body), nil
return body, nil
}

func (p *awsProvider) GetInstanceID(node corev1.Node) string {
Expand All @@ -109,7 +146,6 @@ func (p *awsProvider) GetInstance(ctx context.Context, instanceID string) (*prov
res, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: aws.StringSlice([]string{instanceID}),
})

if err != nil {
return nil, converter.DecodeEC2Error("failed to get instance", err)
}
Expand Down Expand Up @@ -137,7 +173,6 @@ func (p *awsProvider) GetAddress(ctx context.Context, addressID string) (*provid
},
},
})

if err != nil {
return nil, converter.DecodeEC2Error("failed to get address", err)
}
Expand All @@ -156,7 +191,6 @@ func (p *awsProvider) CreateAddress(ctx context.Context) (*provider.Address, err
res, err := p.ec2.AllocateAddress(&ec2.AllocateAddressInput{
Domain: aws.String("vpc"),
})

if err != nil {
return nil, converter.DecodeEC2Error("failed to create address", err)
}
Expand All @@ -168,7 +202,6 @@ func (p *awsProvider) DeleteAddress(ctx context.Context, addressID string) error
_, err := p.ec2.ReleaseAddress(&ec2.ReleaseAddressInput{
AllocationId: aws.String(addressID),
})

if err != nil {
return converter.DecodeEC2Error("failed to delete address", err)
}
Expand All @@ -181,7 +214,6 @@ func (p *awsProvider) AssociateAddress(ctx context.Context, req provider.Associa
AllocationId: aws.String(req.AddressID),
NetworkInterfaceId: aws.String(req.NetworkInterfaceID),
})

if err != nil {
return converter.DecodeEC2Error("failed to associate address", err)
}
Expand All @@ -193,7 +225,6 @@ func (p *awsProvider) DisassociateAddress(ctx context.Context, req provider.Disa
_, err := p.ec2.DisassociateAddress(&ec2.DisassociateAddressInput{
AssociationId: aws.String(req.AssociationID),
})

if err != nil {
return converter.DecodeEC2Error("failed to disassociate address", err)
}
Expand All @@ -205,7 +236,6 @@ func (p *awsProvider) getSecurityGroup(ctx context.Context, firewallRuleID strin
res, err := p.ec2.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
GroupIds: aws.StringSlice([]string{firewallRuleID}),
})

if err != nil {
return nil, converter.DecodeEC2Error("failed to get security group", err)
}
Expand Down Expand Up @@ -238,7 +268,6 @@ func (p *awsProvider) CreateFirewallRuleGroup(ctx context.Context, req provider.
GroupName: aws.String(req.Name),
VpcId: aws.String(vpcID),
})

if err != nil {
return "", converter.DecodeEC2Error("failed to create security group", err)
}
Expand Down Expand Up @@ -275,7 +304,6 @@ func (p *awsProvider) DeleteFirewallRule(ctx context.Context, firewallRuleID str
_, err := p.ec2.DeleteSecurityGroup(&ec2.DeleteSecurityGroupInput{
GroupId: aws.String(firewallRuleID),
})

if err != nil {
return converter.DecodeEC2Error("failed to delete security group", err)
}
Expand All @@ -290,7 +318,6 @@ func (p *awsProvider) authorizeSecurityGroupIngress(ctx context.Context, firewal
converter.EncodeIPPermission(req),
},
})

if err != nil {
return converter.DecodeEC2Error("failed to authorize security group ingress permission", err)
}
Expand All @@ -305,7 +332,6 @@ func (p *awsProvider) revokeSecurityGroupIngress(ctx context.Context, firewallRu
converter.EncodeIPPermission(req),
},
})

if err != nil {
return converter.DecodeEC2Error("failed to revoke security group ingress permission", err)
}
Expand All @@ -320,7 +346,6 @@ func (p *awsProvider) authorizeSecurityGroupEgress(ctx context.Context, firewall
converter.EncodeIPPermission(req),
},
})

if err != nil {
return converter.DecodeEC2Error("failed to authorize security group egress permission", err)
}
Expand All @@ -335,7 +360,6 @@ func (p *awsProvider) revokeSecurityGroupEgress(ctx context.Context, firewallRul
converter.EncodeIPPermission(req),
},
})

if err != nil {
return converter.DecodeEC2Error("failed to revoke security group egress permission", err)
}
Expand All @@ -347,7 +371,6 @@ func (p *awsProvider) AssociateFirewallRule(ctx context.Context, req provider.As
res, err := p.ec2.DescribeNetworkInterfaces(&ec2.DescribeNetworkInterfacesInput{
NetworkInterfaceIds: aws.StringSlice([]string{req.NetworkInterfaceID}),
})

if err != nil {
return err
}
Expand Down Expand Up @@ -380,7 +403,6 @@ func (p *awsProvider) DisassociateFirewallRule(ctx context.Context, req provider
res, err := p.ec2.DescribeNetworkInterfaces(&ec2.DescribeNetworkInterfacesInput{
NetworkInterfaceIds: aws.StringSlice([]string{req.NetworkInterfaceID}),
})

if err != nil {
return converter.DecodeEC2Error("failed to disassociate security group", err)
}
Expand Down

0 comments on commit 9d20ce5

Please sign in to comment.