server: add options to dry run and debug for chat and generate #8165

Draft · wants to merge 5 commits into main

Changes from 1 commit
Add /template endpoint
ParthSareen committed Dec 18, 2024
commit 1d529d8b7b78ba83337332428f3f438ee887aa29
8 changes: 8 additions & 0 deletions api/client.go
@@ -360,6 +360,14 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil
}

func (c *Client) Template(ctx context.Context, req *TemplateRequest) (*TemplateResponse, error) {
var resp TemplateResponse
if err := c.do(ctx, http.MethodPost, "/api/template", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
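
For reference, a minimal usage sketch of the new Client.Template method, assuming the existing api.ClientFromEnvironment helper; the model name and message below are placeholders:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Render the prompt the server would build, without running generation.
	resp, err := client.Template(context.Background(), &api.TemplateRequest{
		Model: "llama3.2", // placeholder model name
		Messages: []api.Message{
			{Role: "user", Content: "Why is the sky blue?"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.TemplatedPrompt)
}
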
10 changes: 10 additions & 0 deletions api/types.go
@@ -310,6 +310,16 @@ type CreateRequest struct {
Quantization string `json:"quantization,omitempty"`
}

type TemplateRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools"`
}

type TemplateResponse struct {
TemplatedPrompt string `json:"templated_prompt"`
}

// DeleteRequest is the request passed to [Client.Delete].
type DeleteRequest struct {
Model string `json:"model"`
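
A rough sketch of the JSON these structs produce, using placeholder values; since Tools has no omitempty tag, it serializes as null when unset:

package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	req := api.TemplateRequest{
		Model:    "llama3.2", // placeholder model name
		Messages: []api.Message{{Role: "user", Content: "hello"}},
	}

	b, err := json.Marshal(req)
	if err != nil {
		log.Fatal(err)
	}
	// Prints roughly:
	// {"model":"llama3.2","messages":[{"role":"user","content":"hello"}],"tools":null}
	fmt.Println(string(b))
}
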
15 changes: 15 additions & 0 deletions server/prompt.go
@@ -142,6 +142,21 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return b.String(), images, nil
}

func applyTemplate(m *Model, msgs []api.Message, tools []api.Tool) (string, error) {
isMllama := checkMllamaModelFamily(m)
for _, msg := range msgs {
if isMllama && len(msg.Images) > 1 {
return "", errTooManyImages
}
}

var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools}); err != nil {
return "", err
}
return b.String(), nil
}

func checkMllamaModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "mllama" {
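
applyTemplate executes the model's template directly over the supplied messages and tools. A standalone sketch of the same template.Values flow, assuming the template package's exported Parse helper; the template string here is a toy example, not a real model template:

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

func main() {
	// Toy chat template; real models ship their own.
	tmpl, err := template.Parse("{{- range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}")
	if err != nil {
		log.Fatal(err)
	}

	var b bytes.Buffer
	if err := tmpl.Execute(&b, template.Values{
		Messages: []api.Message{{Role: "user", Content: "hello"}},
	}); err != nil {
		log.Fatal(err)
	}
	// Prints the rendered prompt, e.g. "user: hello"
	fmt.Print(b.String())
}
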
33 changes: 33 additions & 0 deletions server/routes.go
@@ -1228,6 +1228,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)
r.Any("/api/template", gin.WrapF(s.TemplateHandler))

// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
@@ -1451,6 +1452,38 @@ func (s *Server) PsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}

func (s *Server) TemplateHandler(w http.ResponseWriter, r *http.Request) {
var req api.TemplateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
http.Error(w, fmt.Sprintf("model '%s' not found", req.Model), http.StatusNotFound)
case err.Error() == "invalid model name":
http.Error(w, err.Error(), http.StatusBadRequest)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
}

prompt, err := applyTemplate(model, req.Messages, req.Tools)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if err := json.NewEncoder(w).Encode(api.TemplateResponse{TemplatedPrompt: prompt}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()

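
At the HTTP level the new route accepts a JSON body and returns the rendered prompt. A sketch that POSTs to a local server on the default port (11434), using only net/http plus the new API types; the model name and message are placeholders:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"

	"github.com/ollama/ollama/api"
)

func main() {
	body, err := json.Marshal(api.TemplateRequest{
		Model:    "llama3.2", // placeholder model name
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
	})
	if err != nil {
		log.Fatal(err)
	}

	resp, err := http.Post("http://localhost:11434/api/template", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		log.Fatalf("unexpected status: %s", resp.Status)
	}

	var tr api.TemplateResponse
	if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		log.Fatal(err)
	}
	fmt.Println(tr.TemplatedPrompt)
}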