Skip to content

Commit

Permalink
Create command 'rosa list user-roles'
Browse files Browse the repository at this point in the history
The new command is implemented similarly to the `rosa list ocm-roles` command.
The code checks if a `role name` contains the `User` infix, and in addition,
checks if the role has the tag `rosa_role_type: User`, because 'user' is a common word.

Related: SDA-5412
  • Loading branch information
oriAdler committed Feb 9, 2022
1 parent 120920d commit c4104a6
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cmd/list/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/openshift/rosa/cmd/list/region"
"github.com/openshift/rosa/cmd/list/upgrade"
"github.com/openshift/rosa/cmd/list/user"
"github.com/openshift/rosa/cmd/list/userroles"
"github.com/openshift/rosa/cmd/list/version"
"github.com/openshift/rosa/pkg/arguments"
)
Expand All @@ -55,6 +56,7 @@ func init() {
Cmd.AddCommand(instancetypes.Cmd)
Cmd.AddCommand(accountroles.Cmd)
Cmd.AddCommand(ocmroles.Cmd)
Cmd.AddCommand(userroles.Cmd)
flags := Cmd.PersistentFlags()
arguments.AddProfileFlag(flags)
arguments.AddRegionFlag(flags)
Expand Down
2 changes: 1 addition & 1 deletion cmd/list/ocmroles/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func listOCMRoles(awsClient aws.Client, ocmClient *ocm.Client) ([]aws.Role, erro
if err != nil {
return nil, fmt.Errorf("failed to get organization account: %v", err)
}
linkedRoles, err := ocmClient.GetLinkedRoles(orgID)
linkedRoles, err := ocmClient.GetOrganizationLinkedOCMRoles(orgID)
if err != nil {
return nil, err
}
Expand Down
112 changes: 112 additions & 0 deletions cmd/list/userroles/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
Copyright (c) 2022 Red Hat, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package userroles

import (
"fmt"
"os"
"text/tabwriter"

"github.com/openshift/rosa/pkg/aws"
"github.com/openshift/rosa/pkg/helper"
"github.com/openshift/rosa/pkg/logging"
"github.com/openshift/rosa/pkg/ocm"
"github.com/openshift/rosa/pkg/output"
rprtr "github.com/openshift/rosa/pkg/reporter"
"github.com/spf13/cobra"
)

var Cmd = &cobra.Command{
Use: "user-roles",
Aliases: []string{"userrole", "user-role", "userroles", "user-roles"},
Short: "List user roles",
Long: "List user roles for current AWS account",
Example: `# List all user roles
rosa list user-roles`,
Run: run,
Hidden: true,
}

func init() {
output.AddFlag(Cmd)
}

func run(_ *cobra.Command, _ []string) {
reporter := rprtr.CreateReporterOrExit()
logger := logging.CreateLoggerOrExit(reporter)
awsClient := aws.CreateNewClientOrExit(logger, reporter)
ocmClient := ocm.CreateNewClientOrExit(logger, reporter)
defer func() {
err := ocmClient.Close()
if err != nil {
reporter.Errorf("Failed to close OCM connection: %v", err)
}
}()

userRoles, err := listUserRoles(awsClient, ocmClient)
if err != nil {
reporter.Errorf("Failed to get user roles: %v", err)
os.Exit(1)
}

if len(userRoles) == 0 {
reporter.Infof("No user roles available")
os.Exit(0)
}
if output.HasFlag() {
err = output.Print(userRoles)
if err != nil {
reporter.Errorf("%s", err)
os.Exit(1)
}
os.Exit(0)
}

// Create the writer that will be used to print the tabulated results:
writer := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprint(writer, "ROLE NAME\tROLE ARN\tLINKED\n")
for _, userRole := range userRoles {
fmt.Fprintf(writer, "%s\t%s\t%s\n", userRole.RoleName, userRole.RoleARN, userRole.Linked)
}
writer.Flush()
}

func listUserRoles(awsClient aws.Client, ocmClient *ocm.Client) ([]aws.Role, error) {
userRoles, err := awsClient.ListUserRoles()
if err != nil {
return nil, err
}

// Check if roles are linked to account
account, err := ocmClient.GetCurrentAccount()
if err != nil {
return nil, fmt.Errorf("Failed to get Redhat User Account: %v", err)
}
linkedRoles, err := ocmClient.GetAccountLinkedUserRoles(account.ID())
if err != nil {
return nil, err
}

linkedRolesMap := helper.SliceToMap(linkedRoles)
for i := range userRoles {
_, exist := linkedRolesMap[userRoles[i].RoleARN]
if exist {
userRoles[i].Linked = "Yes"
} else {
userRoles[i].Linked = "No"
}
}

return userRoles, nil
}
15 changes: 15 additions & 0 deletions pkg/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

Expand All @@ -43,6 +44,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"github.com/openshift/rosa/pkg/reporter"
"github.com/sirupsen/logrus"

"github.com/openshift/rosa/pkg/aws/profile"
Expand Down Expand Up @@ -100,6 +102,7 @@ type Client interface {
HasOpenIDConnectProvider(issuerURL string, accountID string) (bool, error)
FindRoleARNs(roleType string, version string) ([]string, error)
FindPolicyARN(operator Operator, version string) (string, error)
ListUserRoles() ([]Role, error)
ListOCMRoles() ([]Role, error)
ListAccountRoles(version string) ([]Role, error)
GetRoleByARN(roleARN string) (*iam.Role, error)
Expand Down Expand Up @@ -144,6 +147,18 @@ type awsClient struct {
awsAccessKeys *AccessKey
}

func CreateNewClientOrExit(logger *logrus.Logger, reporter *reporter.Object) Client {
awsClient, err := NewClient().
Logger(logger).
Build()
if err != nil {
reporter.Errorf("Failed to create AWS client: %v", err)
os.Exit(1)
}

return awsClient
}

// NewClient creates a builder that can then be used to configure and build a new AWS client.
func NewClient() *ClientBuilder {
return &ClientBuilder{}
Expand Down
43 changes: 42 additions & 1 deletion pkg/aws/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ const (
WorkerAccountRole = "instance_worker"
SupportAccountRole = "support"
OCMRole = "OCM"
OCMUserRole = "User"
)

var AccountRoles map[string]AccountRole = map[string]AccountRole{
Expand All @@ -126,7 +127,6 @@ var AccountRoles map[string]AccountRole = map[string]AccountRole{
SupportAccountRole: {Name: "Support", Flag: "support-role-arn"},
}

var OCMUserRole = "User"
var OCMUserRolePolicyFile = "ocm_user"
var OCMRolePolicyFile = "ocm"
var OCMAdminRolePolicyFile = "ocm_admin"
Expand Down Expand Up @@ -690,6 +690,47 @@ func isOCMRole(roleName *string) bool {
return strings.Contains(aws.StringValue(roleName), fmt.Sprintf("%s-Role", OCMRole))
}

// isUserRole checks the role tags in addition to the role name, because the word 'user' is common
func (c *awsClient) isUserRole(roleName *string) (bool, error) {
if strings.Contains(aws.StringValue(roleName), OCMUserRole) {
roleTags, err := c.iamClient.ListRoleTags(&iam.ListRoleTagsInput{
RoleName: roleName,
})
if err != nil {
return false, err
}

return roleHasTag(roleTags.Tags, tags.RoleType, OCMUserRole), nil
}

return false, nil
}

func (c *awsClient) ListUserRoles() ([]Role, error) {
var userRoles []Role
roles, err := c.ListRoles()
if err != nil {
return nil, err
}

for _, role := range roles {
isUserRole, err := c.isUserRole(role.RoleName)
if err != nil {
return nil, err
}

if isUserRole {
var userRole Role
userRole.RoleName = aws.StringValue(role.RoleName)
userRole.RoleARN = aws.StringValue(role.Arn)

userRoles = append(userRoles, userRole)
}
}

return userRoles, nil
}

func (c *awsClient) ListOCMRoles() ([]Role, error) {
var ocmRoles []Role
roles, err := c.ListRoles()
Expand Down
10 changes: 10 additions & 0 deletions pkg/helper/helper.go → pkg/helper/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ func Contains(s []string, str string) bool {

return false
}

func SliceToMap(s []string) map[string]bool {
m := make(map[string]bool)

for _, v := range s {
m[v] = true
}

return m
}
14 changes: 14 additions & 0 deletions pkg/ocm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ocm

import (
"fmt"
"os"
"strings"
"time"

Expand All @@ -26,6 +27,7 @@ import (

"github.com/openshift/rosa/pkg/info"
"github.com/openshift/rosa/pkg/logging"
"github.com/openshift/rosa/pkg/reporter"
)

type Client struct {
Expand All @@ -44,6 +46,18 @@ func NewClient() *ClientBuilder {
return &ClientBuilder{}
}

func CreateNewClientOrExit(logger *logrus.Logger, reporter *reporter.Object) *Client {
client, err := NewClient().
Logger(logger).
Build()
if err != nil {
reporter.Errorf("Failed to create OCM connection: %v", err)
os.Exit(1)
}

return client
}

// Logger sets the logger that the connection will use to send messages to the log. This is
// mandatory.
func (b *ClientBuilder) Logger(value *logrus.Logger) *ClientBuilder {
Expand Down
15 changes: 13 additions & 2 deletions pkg/ocm/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ const (
Username = "Username"
URL = "URL"

OCMRoleLabel = "sts_ocm_role"
OCMRoleLabel = "sts_ocm_role"
USERRoleLabel = "sts_user_role"
)

// Regular expression to used to make sure that the identifier or name given by the user is
Expand Down Expand Up @@ -357,7 +358,17 @@ func (c *Client) LinkOrgToRole(orgID string, roleARN string) (bool, error) {
return true, nil
}

func (c *Client) GetLinkedRoles(orgID string) ([]string, error) {
func (c *Client) GetAccountLinkedUserRoles(accountID string) ([]string, error) {
resp, err := c.ocm.AccountsMgmt().V1().Accounts().Account(accountID).
Labels().Labels(USERRoleLabel).Get().Send()
if err != nil && resp.Status() != http.StatusNotFound {
return nil, handleErr(resp.Error(), err)
}

return strings.Split(resp.Body().Value(), ","), nil
}

func (c *Client) GetOrganizationLinkedOCMRoles(orgID string) ([]string, error) {
resp, err := c.ocm.AccountsMgmt().V1().Organizations().Organization(orgID).
Labels().Labels(OCMRoleLabel).Get().Send()
if err != nil && resp.Status() != http.StatusNotFound {
Expand Down

0 comments on commit c4104a6

Please sign in to comment.