Skip to content

Commit

Permalink
Refactor Neo API to enhance CORS handling and streamline endpoint reg…
Browse files Browse the repository at this point in the history
…istration. Introduce OPTIONS handlers for all endpoints, improve CORS middleware logic, and reorganize assistant creation flow in load.go. Additionally, add a delay in the Answer method to ensure proper retrieval of assistant and chat IDs.
  • Loading branch information
trheyi committed Dec 14, 2024
1 parent 79ee5b7 commit 6c244d8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 43 deletions.
103 changes: 64 additions & 39 deletions neo/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ func (neo *DSL) API(router *gin.Engine, path string) error {
return err
}

// Cross-Domain handlers
cors, err := neo.getCorsHandlers(router, path)
if err != nil {
return err
}

// Append cors handlers
middlewares = append(middlewares, cors...)
// Register OPTIONS handlers for all endpoints
router.OPTIONS(path, neo.optionsHandler)
router.OPTIONS(path+"/status", neo.optionsHandler)
router.OPTIONS(path+"/chats", neo.optionsHandler)
router.OPTIONS(path+"/history", neo.optionsHandler)

// Register chat endpoint
// Register endpoints with middlewares
router.GET(path, append(middlewares, neo.handleChat)...)
router.POST(path, append(middlewares, neo.handleChat)...)

// Register chat list endpoint
router.GET(path+"/status", append(middlewares, neo.handleStatus)...)
router.GET(path+"/chats", append(middlewares, neo.handleChatList)...)

// Register chat history endpoint
router.GET(path+"/history", append(middlewares, neo.handleChatHistory)...)

return nil
}

// handleStatus handles the status request
func (neo *DSL) handleStatus(c *gin.Context) {
c.Status(200)
c.Done()
}

// handleChat handles the chat request
func (neo *DSL) handleChat(c *gin.Context) {
// Set headers for SSE
Expand Down Expand Up @@ -112,7 +112,7 @@ func (neo *DSL) handleChatHistory(c *gin.Context) {
}

// getCorsHandlers returns CORS middleware handlers
func (neo *DSL) getCorsHandlers(router *gin.Engine, path string) ([]gin.HandlerFunc, error) {
func (neo *DSL) getCorsHandlers() ([]gin.HandlerFunc, error) {
if len(neo.Allows) == 0 {
return []gin.HandlerFunc{}, nil
}
Expand All @@ -124,66 +124,91 @@ func (neo *DSL) getCorsHandlers(router *gin.Engine, path string) ([]gin.HandlerF
allowsMap[allow] = true
}

router.OPTIONS(path+"/history", neo.optionsHandler)
router.OPTIONS(path+"/commands", neo.optionsHandler)
return []gin.HandlerFunc{neo.corsMiddleware(allowsMap)}, nil
}

// corsMiddleware handles CORS requests
func (neo *DSL) corsMiddleware(allowsMap map[string]bool) gin.HandlerFunc {
return func(c *gin.Context) {
referer := neo.getOrigin(c)
if referer != "" {
if !api.IsAllowed(c, allowsMap) {
c.JSON(403, gin.H{"message": referer + " not allowed", "code": 403})
c.Abort()
return
}
url, _ := url.Parse(referer)
referer = fmt.Sprintf("%s://%s", url.Scheme, url.Host)
c.Writer.Header().Set("Access-Control-Allow-Origin", referer)
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
origin := neo.getOrigin(c)
if origin == "" {
c.Next()
return
}

// Check if origin is allowed
if !api.IsAllowed(c, allowsMap) {
c.AbortWithStatusJSON(403, gin.H{
"message": origin + " not allowed",
"code": 403,
})
return
}

// Set CORS headers
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, Accept, Origin, Cache-Control, X-Requested-With")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")

if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}

c.Next()
}
}

// optionsHandler handles OPTIONS requests
func (neo *DSL) optionsHandler(c *gin.Context) {
origin := neo.getOrigin(c)
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
if origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400") // 24 hours
}
c.AbortWithStatus(204)
}

// getOrigin returns the request origin
func (neo *DSL) getOrigin(c *gin.Context) string {
referer := c.Request.Referer()
origin := c.Request.Header.Get("Origin")
if origin == "" {
origin = referer
origin = c.Request.Referer()
if origin != "" {
if u, err := url.Parse(origin); err == nil {
origin = fmt.Sprintf("%s://%s", u.Scheme, u.Host)
}
}
}
return origin
}

// getGuardHandlers returns authentication middleware handlers
func (neo *DSL) getGuardHandlers() ([]gin.HandlerFunc, error) {

// Cross-Domain handlers
cors, err := neo.getCorsHandlers()
if err != nil {
return nil, err
}

if neo.Guard == "" {
return []gin.HandlerFunc{neo.defaultGuard}, nil
middlewares := append(cors, neo.defaultGuard)
return middlewares, nil
}

// Validate the custom guard
_, err := process.Of(neo.Guard)
_, err = process.Of(neo.Guard)
if err != nil {
return nil, err
}

// Return custom guard
return []gin.HandlerFunc{api.ProcessGuard(neo.Guard)}, nil
middlewares := append(cors, api.ProcessGuard(neo.Guard, cors...))
return middlewares, nil
}

// defaultGuard is the default authentication handler
Expand Down
8 changes: 4 additions & 4 deletions neo/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ func Load(cfg config.Config) error {

Neo = &setting

// Create Default Assistant
Neo.Assistant, err = Neo.createDefaultAssistant()
// Conversation Setting
err = Neo.createConversation()
if err != nil {
return err
}

// Conversation Setting
err = Neo.createConversation()
// Create Default Assistant
Neo.Assistant, err = Neo.createDefaultAssistant()
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions neo/neo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"strings"
"sync"
"time"

"github.com/fatih/color"
"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -36,6 +37,9 @@ func (neo *DSL) Answer(ctx Context, question string, c *gin.Context) error {
return err
}

// Get the assistant_id, chat_id
time.Sleep(1 * time.Second)

// Send a text message to the client
msg := message.New().Map(map[string]interface{}{
"text": "Hello, world!",
Expand Down

0 comments on commit 6c244d8

Please sign in to comment.