Skip to content

Commit

Permalink
[add] yao studio api jwt auth support
Browse files Browse the repository at this point in the history
  • Loading branch information
trheyi committed Oct 14, 2022
1 parent 25bd1f3 commit 9fb8ef2
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 24 deletions.
14 changes: 14 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"errors"
"os"
"path/filepath"
"strings"

"github.com/caarlos0/env/v6"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/joho/godotenv"
"github.com/yaoapp/kun/exception"
"github.com/yaoapp/kun/log"
"github.com/yaoapp/yao/crypto"
)

// Conf 配置参数
Expand Down Expand Up @@ -54,6 +57,17 @@ func Load() Config {
exception.New("Can't read config %s", 500, err.Error()).Throw()
}
cfg.Root, _ = filepath.Abs(cfg.Root)

// Studio Secret
if cfg.Studio.Secret == nil {
v, err := crypto.Hash(crypto.HashTypes["SHA256"], uuid.New().String())
if err != nil {
exception.New("Can't gengrate studio secret %s", 500, err.Error()).Throw()
}
cfg.Studio.Secret = []byte(strings.ToUpper(v))
cfg.Studio.Auto = true
}

return cfg
}

Expand Down
5 changes: 3 additions & 2 deletions config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ type Config struct {

// StudioConfig the studio config
type StudioConfig struct {
Port int `json:"studio_port,omitempty" env:"YAO_STUDIO_PORT" envDefault:"5077"` // Studio port
Secret int `json:"studio_secret,omitempty" env:"YAO_STUDIO_SECRET"` // Studio Secret, if does not set, auto-generate a secret
Port int `json:"studio_port,omitempty" env:"YAO_STUDIO_PORT" envDefault:"5077"` // Studio port
Secret []byte `json:"studio_secret,omitempty" env:"YAO_STUDIO_SECRET"` // Studio Secret, if does not set, auto-generate a secret
Auto bool `json:"-"`
}

// DBConfig 数据库配置
Expand Down
3 changes: 2 additions & 1 deletion crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ func init() {
crypto.RegisterHash(crypto.MD4, md4.New)
}

var hashTypes = map[string]crypto.Hash{
// HashTypes string
var HashTypes = map[string]crypto.Hash{
"MD4": crypto.MD4,
"MD5": crypto.MD5,
"SHA1": crypto.SHA1,
Expand Down
4 changes: 2 additions & 2 deletions crypto/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func ProcessHash(process *gou.Process) interface{} {
typ := process.ArgsString(0)
value := process.ArgsString(1)

h, has := hashTypes[typ]
h, has := HashTypes[typ]
if !has {
exception.New("%s does not support", 400, typ).Throw()
}
Expand All @@ -41,7 +41,7 @@ func ProcessHmac(process *gou.Process) interface{} {
value := process.ArgsString(1)
key := process.ArgsString(2)

h, has := hashTypes[typ]
h, has := HashTypes[typ]
if !has {
exception.New("%s does not support", 400, typ).Throw()
}
Expand Down
20 changes: 16 additions & 4 deletions helper/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ type JwtToken struct {
}

// JwtValidate JWT 校验
func JwtValidate(tokenString string) *JwtClaims {
func JwtValidate(tokenString string, secret ...[]byte) *JwtClaims {

jwtSecret := []byte(config.Conf.JWTSecret)
if len(secret) > 0 {
jwtSecret = secret[0]
}

token, err := jwt.ParseWithClaims(tokenString, &JwtClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(config.Conf.JWTSecret), nil
return jwtSecret, nil
})

if err != nil {
Expand All @@ -49,7 +55,13 @@ func JwtValidate(tokenString string) *JwtClaims {

// JwtMake 生成 JWT
// option: {"subject":"<主题>", "audience": "<接收人>", "issuer":"<签发人>", "timeout": "<有效期,单位秒>", "sid":"<会话ID>"}
func JwtMake(id int, data map[string]interface{}, option map[string]interface{}) JwtToken {
func JwtMake(id int, data map[string]interface{}, option map[string]interface{}, secret ...[]byte) JwtToken {

jwtSecret := []byte(config.Conf.JWTSecret)
if len(secret) > 0 {
jwtSecret = secret[0]
}

now := time.Now().Unix()
sid := ""
timeout := int64(3600)
Expand Down Expand Up @@ -96,7 +108,7 @@ func JwtMake(id int, data map[string]interface{}, option map[string]interface{})
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.Conf.JWTSecret))
tokenString, err := token.SignedString([]byte(jwtSecret))
if err != nil {
exception.New("生成令牌失败", 500).Ctx(err).Throw()
}
Expand Down
10 changes: 8 additions & 2 deletions importer/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ func Open(name string) from.Source {
return nil
}

// WithSid attch sid
func (imp *Importer) WithSid(sid string) *Importer {
imp.Sid = sid
return imp
}

// AutoMapping 根据文件信息获取字段映射表
func (imp *Importer) AutoMapping(src from.Source) *Mapping {
sourceColumns := getSourceColumns(src)
Expand Down Expand Up @@ -405,7 +411,7 @@ func (imp *Importer) Run(src from.Source, mapping *Mapping) interface{} {
return
}

response, err := process.Exec()
response, err := process.WithSID(imp.Sid).Exec()
if err != nil {
failed = failed + length
log.With(log.F{"line": line}).Error("导入失败: %s", err.Error())
Expand Down Expand Up @@ -437,7 +443,7 @@ func (imp *Importer) Run(src from.Source, mapping *Mapping) interface{} {
}

if imp.Output != "" {
res, err := gou.NewProcess(imp.Output, output).Exec()
res, err := gou.NewProcess(imp.Output, output).WithSID(imp.Sid).Exec()
if err != nil {
log.With(log.F{"output": imp.Output}).Error(err.Error())
return output
Expand Down
12 changes: 6 additions & 6 deletions importer/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func init() {
func ProcessRun(process *gou.Process) interface{} {
process.ValidateArgNums(3)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)
filename := process.ArgsString(1)
src := Open(filename)
defer src.Close()
Expand All @@ -34,7 +34,7 @@ func ProcessRun(process *gou.Process) interface{} {
func ProcessSetting(process *gou.Process) interface{} {
process.ValidateArgNums(1)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)
return map[string]interface{}{
"mappingPreview": imp.Option.MappingPreview,
"dataPreview": imp.Option.DataPreview,
Expand All @@ -48,7 +48,7 @@ func ProcessSetting(process *gou.Process) interface{} {
func ProcessData(process *gou.Process) interface{} {
process.ValidateArgNums(5)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)

filename := process.ArgsString(1)
src := Open(filename)
Expand All @@ -66,7 +66,7 @@ func ProcessData(process *gou.Process) interface{} {
func ProcessDataSetting(process *gou.Process) interface{} {
process.ValidateArgNums(1)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)
return imp.DataSetting()
}

Expand All @@ -75,7 +75,7 @@ func ProcessDataSetting(process *gou.Process) interface{} {
func ProcessMapping(process *gou.Process) interface{} {
process.ValidateArgNums(2)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)

filename := process.ArgsString(1)
src := Open(filename)
Expand All @@ -88,7 +88,7 @@ func ProcessMapping(process *gou.Process) interface{} {
func ProcessMappingSetting(process *gou.Process) interface{} {
process.ValidateArgNums(2)
name := process.ArgsString(0)
imp := Select(name)
imp := Select(name).WithSid(process.Sid)

filename := process.ArgsString(1)
src := Open(filename)
Expand Down
1 change: 1 addition & 0 deletions importer/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Importer struct {
Columns []Column `json:"columns"` // 字段列表
Option Option `json:"option,omitempty"` // 导入配置项
Rules map[string]string `json:"rules,omitempty"` // 许可导入规则
Sid string `json:"-"` // sid
}

// Column 导入字段定义
Expand Down
39 changes: 39 additions & 0 deletions studio/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package studio
import (
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/yaoapp/kun/exception"
"github.com/yaoapp/kun/log"
"github.com/yaoapp/xun"
"github.com/yaoapp/yao/config"
"github.com/yaoapp/yao/helper"
)

// hdRecovered custom recovered
func hdRecovered(c *gin.Context, recovered interface{}) {

var code = http.StatusInternalServerError
Expand Down Expand Up @@ -39,3 +44,37 @@ func hdRecovered(c *gin.Context, recovered interface{}) {

c.AbortWithStatus(code)
}

// cross domian
func hdCrossDomain(c *gin.Context) {
}

// studio API Auth
func hdAuth(c *gin.Context) {

tokenString := c.Request.Header.Get("Authorization")
if strings.HasPrefix(tokenString, "Bearer") {
tokenString = strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer "))
if tokenString == "" {
c.JSON(403, gin.H{"code": 403, "message": "No permission"})
c.Abort()
return
}

claims := helper.JwtValidate(tokenString, config.Conf.Studio.Secret)
c.Set("__sid", claims.SID)
return

} else if strings.HasPrefix(tokenString, "Signature ") { // For Yao Studio
signature := strings.TrimSpace(strings.TrimPrefix(tokenString, "Signature "))
nonce := c.Request.Header.Get("Studio-Nonce")
ts := c.Request.Header.Get("Studio-Timestamp")
query := c.Request.URL.Query()
log.Trace("[Studio] %s, %s %s %v", signature, nonce, ts, query)
return
}

c.JSON(403, gin.H{"code": 403, "message": "No permission"})
c.Abort()
return
}
9 changes: 7 additions & 2 deletions studio/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var regExcp = regexp.MustCompile("^Exception\\|([0-9]+):(.+)$")
// Serve start the api server
func setRouter(router *gin.Engine) {

router.Use(gin.CustomRecovery(hdRecovered))
router.Use(gin.CustomRecovery(hdRecovered), hdAuth)

// DSL ReadDir, ReadFile
router.GET("/dsl/:method", func(c *gin.Context) {
Expand Down Expand Up @@ -202,7 +202,12 @@ func setRouter(router *gin.Engine) {
return
}

res, err := gou.Yao.Engine.Call(map[string]interface{}{}, service, fun.Method, fun.Args...)
req := gou.Yao.New(service, fun.Method)
if sid, has := c.Get("__sid"); has {
req.WithSid(fmt.Sprintf("%s", sid))
}

res, err := req.Call(fun.Args...)
if err != nil {
// parse Exception
code := 500
Expand Down
34 changes: 29 additions & 5 deletions studio/studio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/yaoapp/gou"
"github.com/yaoapp/yao/config"
"github.com/yaoapp/yao/helper"
)

type kv map[string]interface{}
Expand Down Expand Up @@ -57,7 +58,7 @@ func TestStartStopError(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}

func TestGetAPI(t *testing.T) {
func TestAPI(t *testing.T) {

Load(config.Conf)

Expand Down Expand Up @@ -132,7 +133,16 @@ func httpGet[T kv | arr | interface{} | map[string]interface{} | int | []interfa

var data T
url = fmt.Sprintf("http://127.0.0.1:%d%s", config.Conf.Studio.Port, url)
res, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Fatal(err)
}

token := getToken(t)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

client := http.Client{}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
Expand All @@ -143,10 +153,10 @@ func httpGet[T kv | arr | interface{} | map[string]interface{} | int | []interfa
t.Fatal(err)
}

if body != nil {
if body != nil && len(body) > 0 {
err = jsoniter.Unmarshal(body, &data)
if err != nil {
t.Fatal(err)
t.Fatal(fmt.Sprintf("%s\n%s\n", err.Error(), string(body)))
}
}
}
Expand All @@ -164,11 +174,16 @@ func httpPost[T kv | arr | interface{} | map[string]interface{} | int | []interf
}

url = fmt.Sprintf("http://127.0.0.1:%d%s", config.Conf.Studio.Port, url)
res, err := http.Post(url, "application/json", buff)
req, err := http.NewRequest("POST", url, buff)
if err != nil {
t.Fatal(err)
}

token := getToken(t)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

client := http.Client{}
res, err := client.Do(req)
if res.Body != nil {
body, err := io.ReadAll(res.Body)
if err != nil {
Expand Down Expand Up @@ -197,3 +212,12 @@ func httpPostJSON[T kv | arr | interface{} | map[string]interface{} | int | []in
}
return httpPost[T](url, data, t)
}

func getToken(t *testing.T) string {
return helper.JwtMake(
1,
map[string]interface{}{"id": 1, "user_id": 1, "user": kv{"id": 1, "name": "test"}},
map[string]interface{}{"issuer": "unit-test", "timeout": 3600},
config.Conf.Studio.Secret,
).Token
}

0 comments on commit 9fb8ef2

Please sign in to comment.