Skip to content

Commit

Permalink
Model predict api (Partial) (OpenCSGs#28)
Browse files Browse the repository at this point in the history
Add model inference API (partial implementation; the call to the backend inference service is not yet implemented)
  • Loading branch information
Rader authored Feb 21, 2024
1 parent 12a8293 commit 53d6253
Show file tree
Hide file tree
Showing 13 changed files with 490 additions and 50 deletions.
42 changes: 42 additions & 0 deletions api/handler/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,48 @@ func (h *ModelHandler) UpdateDownloads(ctx *gin.Context) {
httpbase.OK(ctx, nil)
}

// Predict godoc
// @Security ApiKey
// @Summary Invoke model prediction
// @Description invoke model prediction
// @Tags Model
// @Accept json
// @Produce json
// @Param namespace path string true "namespace"
// @Param name path string true "name"
// @Param body body types.ModelPredictReq true "input for model prediction"
// @Success 200 {object} string "OK"
// @Failure 400 {object} types.APIBadRequest "Bad request"
// @Failure 500 {object} types.APIInternalServerError "Internal server error"
// @Router /models/{namespace}/{name}/predict [post]
func (h *ModelHandler) Predict(ctx *gin.Context) {
	namespace, name, err := common.GetNamespaceAndNameFromContext(ctx)
	if err != nil {
		slog.Error("Bad request format", "error", err)
		httpbase.BadRequest(ctx, err.Error())
		return
	}

	// Bind the JSON payload after the path params so a malformed URL is
	// rejected before the body is read.
	var req types.ModelPredictReq
	if err := ctx.ShouldBindJSON(&req); err != nil {
		slog.Error("Bad request format", "error", err)
		httpbase.BadRequest(ctx, err.Error())
		return
	}
	req.Namespace = namespace
	req.Name = name

	// Delegate the actual inference call to the model component.
	resp, err := h.c.Predict(ctx, &req)
	if err != nil {
		slog.Error("fail to call predict", slog.String("error", err.Error()))
		httpbase.ServerError(ctx, err)
		return
	}
	httpbase.OK(ctx, resp)
}

func parseTagReqs(ctx *gin.Context) (tags []database.TagReq) {
licenseTag := ctx.Query("license_tag")
taskTag := ctx.Query("task_tag")
Expand Down
65 changes: 34 additions & 31 deletions api/router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (

func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) {
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Log())

if enableSwagger {
r.GET("/api/v1/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
Expand Down Expand Up @@ -42,8 +44,6 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) {
}

r.Use(middleware.Authenticator(config))
r.Use(gin.Recovery())
r.Use(middleware.Log())
apiGroup := r.Group("/api/v1")
// TODO:use middleware to handle common response

Expand All @@ -56,33 +56,36 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) {
apiGroup.POST("/list/datasets_by_path", listHandler.ListDatasetsByPath)

// Models routes

apiGroup.POST("/models", modelHandler.Create)
apiGroup.GET("/models", modelHandler.Index)
apiGroup.PUT("/models/:namespace/:name", modelHandler.Update)
apiGroup.DELETE("/models/:namespace/:name", modelHandler.Delete)
apiGroup.GET("/models/:namespace/:name", modelHandler.Show)
apiGroup.GET("/models/:namespace/:name/detail", modelHandler.Detail)
apiGroup.GET("/models/:namespace/:name/branches", modelHandler.Branches)
apiGroup.GET("/models/:namespace/:name/tags", modelHandler.Tags)
apiGroup.GET("/models/:namespace/:name/last_commit", modelHandler.LastCommit)
apiGroup.GET("/models/:namespace/:name/tree", modelHandler.Tree)
apiGroup.GET("/models/:namespace/:name/commits", modelHandler.Commits)
apiGroup.GET("/models/:namespace/:name/raw/*file_path", modelHandler.FileRaw)
// The DownloadFile method differs from the SDKDownload interface in a few ways

// 1.When passing the file_path parameter to the SDKDownload method,
// it only needs to pass the path of the file itself,
// whether it is an lfs file or a non-lfs file.
// The DownloadFile has a different file_path format for lfs files and non-lfs files,
// and an lfs parameter needs to be added.
// 2. DownloadFile returns an object store url for lfs files, while SDKDownload redirects directly.
apiGroup.GET("/models/:namespace/:name/download/*file_path", modelHandler.DownloadFile)
apiGroup.GET("/models/:namespace/:name/resolve/:branch/*file_path", modelHandler.SDKDownload)
apiGroup.POST("/models/:namespace/:name/raw/*file_path", modelHandler.CreateFile)
apiGroup.PUT("/models/:namespace/:name/raw/*file_path", modelHandler.UpdateFile)
apiGroup.POST("/models/:namespace/:name/update_downloads", modelHandler.UpdateDownloads)
apiGroup.POST("/models/:namespace/:name/upload_file", modelHandler.UploadFile)
modelsGroup := apiGroup.Group("/models")
{
modelsGroup.POST("", modelHandler.Create)
modelsGroup.GET("", modelHandler.Index)
modelsGroup.PUT("/:namespace/:name", modelHandler.Update)
modelsGroup.DELETE("/:namespace/:name", modelHandler.Delete)
modelsGroup.GET("/:namespace/:name", modelHandler.Show)
modelsGroup.GET("/:namespace/:name/detail", modelHandler.Detail)
modelsGroup.GET("/:namespace/:name/branches", modelHandler.Branches)
modelsGroup.GET("/:namespace/:name/tags", modelHandler.Tags)
modelsGroup.GET("/:namespace/:name/last_commit", modelHandler.LastCommit)
modelsGroup.GET("/:namespace/:name/tree", modelHandler.Tree)
modelsGroup.GET("/:namespace/:name/commits", modelHandler.Commits)
modelsGroup.GET("/:namespace/:name/raw/*file_path", modelHandler.FileRaw)
// The DownloadFile method differs from the SDKDownload interface in a few ways

// 1.When passing the file_path parameter to the SDKDownload method,
// it only needs to pass the path of the file itself,
// whether it is an lfs file or a non-lfs file.
// The DownloadFile has a different file_path format for lfs files and non-lfs files,
// and an lfs parameter needs to be added.
// 2. DownloadFile returns an object store url for lfs files, while SDKDownload redirects directly.
modelsGroup.GET("/:namespace/:name/download/*file_path", modelHandler.DownloadFile)
modelsGroup.POST("/:namespace/:name/raw/*file_path", modelHandler.CreateFile)
modelsGroup.PUT("/:namespace/:name/raw/*file_path", modelHandler.UpdateFile)
modelsGroup.POST("/:namespace/:name/update_downloads", modelHandler.UpdateDownloads)
modelsGroup.POST("/:namespace/:name/upload_file", modelHandler.UploadFile)
// invoke model endpoint to do prediction
modelsGroup.POST("/:namespace/:name/predict", modelHandler.Predict)
}

// Dataset routes

Expand Down Expand Up @@ -119,8 +122,8 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) {
spaces := apiGroup.Group("/spaces")
{
// list all spaces
spaces.GET("/", spaceHandler.Index)
spaces.POST("/", spaceHandler.Create)
spaces.GET("", spaceHandler.Index)
spaces.POST("", spaceHandler.Create)
// show a user or org's space
spaces.GET("/:namespace/:name", spaceHandler.Get)
spaces.PUT("/:namespace/:name", spaceHandler.Update)
Expand Down
27 changes: 27 additions & 0 deletions builder/inference/app.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package inference

// App is the interface to a model inference backend: given a model identity
// and a prompt-style request, it returns the generated prediction.
type App interface {
	// Predict runs inference for the model identified by id.
	Predict(id ModelID, req *PredictRequest) (*PredictResponse, error)
}

// PredictRequest is the JSON payload sent to the inference service.
type PredictRequest struct {
	// Prompt is the input text for the model.
	Prompt string `json:"prompt"`
}

// PredictResponse mirrors the JSON body returned by the inference service.
// Field semantics follow the backend's schema; the *_batch fields presumably
// aggregate over a batched request, and times look like seconds — TODO confirm
// against the inference service's API documentation.
type PredictResponse struct {
	GeneratedText               string  `json:"generated_text"`
	NumInputTokens              int     `json:"num_input_tokens"`
	NumInputTokensBatch         int     `json:"num_input_tokens_batch"`
	NumGeneratedTokens          int     `json:"num_generated_tokens"`
	NumGeneratedTokensBatch     int     `json:"num_generated_tokens_batch"`
	PreprocessingTime           float64 `json:"preprocessing_time"`
	GenerationTime              float64 `json:"generation_time"`
	PostprocessingTime          float64 `json:"postprocessing_time"`
	GenerationTimePerToken      float64 `json:"generation_time_per_token"`
	GenerationTimePerTokenBatch float64 `json:"generation_time_per_token_batch"`
	NumTotalTokens              int     `json:"num_total_tokens"`
	NumTotalTokensBatch         int     `json:"num_total_tokens_batch"`
	TotalTime                   float64 `json:"total_time"`
	TotalTimePerToken           float64 `json:"total_time_per_token"`
	TotalTimePerTokenBatch      float64 `json:"total_time_per_token_batch"`
}
6 changes: 6 additions & 0 deletions builder/inference/init.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package inference

import "opencsg.com/csghub-server/common/config"

// Init prepares the inference package from global configuration.
// It is currently a no-op placeholder kept so callers have a stable
// initialization hook once the backend wiring lands.
// The parameter is named cfg (not config) to avoid shadowing the
// imported config package.
func Init(cfg *config.Config) {
	// TODO: construct the inference client from cfg (e.g. cfg.Inference.ServerAddr).
	_ = cfg
}
124 changes: 124 additions & 0 deletions builder/inference/llm_infer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package inference

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"io"
"net/http"
"time"
)

type ModelID struct {
Owner, Name, Version string
}

func (m ModelID) Hash() uint64 {
f := fnv.New64()
f.Write([]byte(m.Owner))
f.Write([]byte(":"))
f.Write([]byte(m.Name))
f.Write([]byte(":"))
f.Write([]byte(m.Version))
return f.Sum64()
}

// Compile-time check that llmInferClient satisfies the App interface.
var _ App = (*llmInferClient)(nil)

// ModelInfo describes one deployed model service instance as reported by
// the inference server.
type ModelInfo struct {
	// Endpoint is the base URL the prediction request is POSTed to.
	Endpoint string
	// deploy,running,failed etc
	Status string
	// ModelID.Hash()
	HashID uint64
}

// llmInferClient implements App against an HTTP inference server.
// It keeps a 30-second local cache of the serving list (modelServices,
// keyed by ModelID.Hash()).
// NOTE(review): lastUpdate/modelServices are read and written without a
// lock; this type looks unsafe for concurrent use — confirm callers are
// single-threaded or add a mutex.
type llmInferClient struct {
	lastUpdate    time.Time
	hc            *http.Client
	modelServices map[uint64]ModelInfo
	serverAddr    string
}

// NewInferClient returns an App backed by the inference server at addr.
// It builds a dedicated *http.Client with a 5-second timeout instead of
// mutating http.DefaultClient — the original set DefaultClient.Timeout,
// which silently changed timeout behavior for every other user of the
// shared default client in the process.
func NewInferClient(addr string) App {
	return &llmInferClient{
		hc:         &http.Client{Timeout: 5 * time.Second},
		serverAddr: addr,
	}
}

// Predict resolves the serving endpoint for id and forwards the request.
func (c *llmInferClient) Predict(id ModelID, req *PredictRequest) (*PredictResponse, error) {
	svc, err := c.GetModelService(id)
	if err != nil {
		return nil, fmt.Errorf("failed to get model info,error:%w", err)
	}

	// for test only, as inference service is not ready
	if id.Owner == "test_user_name" && id.Name == "test_model_name" {
		return &PredictResponse{GeneratedText: "this is a test predict result."}, nil
	}

	return c.CallPredict(svc.Endpoint, req)
}

// ServingList returns the models currently served, keyed by ModelID.Hash().
// Results are cached locally for 30 seconds.
// NOTE(review): the cache fields are accessed without a lock — confirm
// this client is not shared across goroutines.
func (c *llmInferClient) ServingList() (map[uint64]ModelInfo, error) {
	// Serve from the local cache while it is still fresh.
	if time.Since(c.lastUpdate) < 30*time.Second {
		return c.modelServices, nil
	}

	// TODO: call inference service to get all serving models
	// c.hc.Post()
	testModelID := ModelID{
		Owner:   "test_user_name",
		Name:    "test_model_name",
		Version: "",
	}
	refreshed := map[uint64]ModelInfo{
		testModelID.Hash(): {
			HashID:   testModelID.Hash(),
			Endpoint: "http://localhost:8080/test_user_name/test_model_name",
			Status:   "running",
		},
	}

	c.modelServices = refreshed
	c.lastUpdate = time.Now()
	return c.modelServices, nil
}

// GetModelService looks up the serving info for id in the current
// serving list.
func (c *llmInferClient) GetModelService(id ModelID) (ModelInfo, error) {
	services, err := c.ServingList()
	if err != nil {
		return ModelInfo{}, err
	}

	svc, ok := services[id.Hash()]
	if !ok {
		return ModelInfo{}, errors.New("model service not found by id")
	}
	return svc, nil
}

// CallPredict POSTs req as JSON to url and decodes the prediction response.
// Fixes over the original: the json.Encode error was silently ignored, the
// HTTP status code was never checked (a 500 with an HTML body would be fed
// to the JSON decoder), and the body is now stream-decoded instead of
// ReadAll+Unmarshal.
func (c *llmInferClient) CallPredict(url string, req *PredictRequest) (*PredictResponse, error) {
	var body bytes.Buffer
	if err := json.NewEncoder(&body).Encode(req); err != nil {
		return nil, fmt.Errorf("failed to encode request body,error: %w", err)
	}

	resp, err := c.hc.Post(url, "application/json", &body)
	if err != nil {
		return nil, fmt.Errorf("failed to send http request,error: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("inference service returned unexpected status: %d", resp.StatusCode)
	}

	var r PredictResponse
	if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
		return nil, fmt.Errorf("failed to decode response body,error: %w", err)
	}
	return &r, nil
}
4 changes: 4 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ type Config struct {
JWT struct {
SigningKey string `envconfig:"STARHUB_JWT_SIGNING_KEY" default:"signing-key"`
}

Inference struct {
ServerAddr string `envconfig:"STARHUB_SERVER_INFERENCE_SERVER_ADDR" default:"http://localhost:8256"`
}
}

func LoadConfig() (cfg *Config, err error) {
Expand Down
13 changes: 13 additions & 0 deletions common/types/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,16 @@ type UpdateDownloadsReq struct {
Date time.Time
CloneCount int64 `json:"download_count"`
}

// ModelPredictReq is the request body for the model prediction API.
// Namespace and Name are filled from the URL path, not the JSON body.
type ModelPredictReq struct {
	Namespace   string `json:"-"`
	Name        string `json:"-"`
	Input       string `json:"input"`
	Version     string `json:"version"`
	CurrentUser string `json:"current_user"`
}

// ModelPredictResp is the response body for the model prediction API.
type ModelPredictResp struct {
	// Content is the generated prediction text.
	Content string `json:"content"`
	// TODO:add metrics like tokens, latency etc
}
Loading

0 comments on commit 53d6253

Please sign in to comment.