Skip to content

Commit

Permalink
🔖 chore: Structured output of Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Dec 19, 2024
1 parent 5059920 commit 28d23f3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
53 changes: 53 additions & 0 deletions providers/gemini/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,62 @@ func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq
geminiRequest.Stream = request.Stream
geminiRequest.Model = request.Model

if request.ResponseFormat != nil && (request.ResponseFormat.Type == "json_schema" || request.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"

if request.ResponseFormat.JsonSchema != nil && request.ResponseFormat.JsonSchema.Schema != nil {
cleanedSchema := removeAdditionalProperties(request.ResponseFormat.JsonSchema.Schema)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}

return &geminiRequest, nil
}

func removeAdditionalProperties(schema interface{}) interface{} {
return removeAdditionalPropertiesWithDepth(schema, 0)
}

func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
}

v, ok := schema.(map[string]interface{})
if !ok || len(v) == 0 {
return schema
}

// 如果type不为object和array,则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}

switch v["type"] {
case "object":
delete(v, "additionalProperties")
// 处理 properties
if properties, ok := v["properties"].(map[string]interface{}); ok {
for key, value := range properties {
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
}
}
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := v[field].([]interface{}); ok {
for i, item := range nested {
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
}
}
}
case "array":
if items, ok := v["items"].(map[string]interface{}); ok {
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
}
}

return v
}

func ConvertToChatOpenai(provider base.ProviderInterface, response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
aiError := errorHandle(&response.GeminiErrorResponse)
if aiError != nil {
Expand Down
11 changes: 9 additions & 2 deletions types/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,15 @@ type ChatMessagePart struct {
}

type ChatCompletionResponseFormat struct {
Type string `json:"type,omitempty"`
JsonSchema any `json:"json_schema,omitempty"`
Type string `json:"type,omitempty"`
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
}

type FormatJsonSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema any `json:"schema,omitempty"`
Strict any `json:"strict,omitempty"`
}

type ChatCompletionRequest struct {
Expand Down

0 comments on commit 28d23f3

Please sign in to comment.