Skip to content

Commit

Permalink
Merge pull request ConnectAI-E#264 from wenerme/wener
Browse files Browse the repository at this point in the history
feat: default to dall-e-3, support more image res, add image style
  • Loading branch information
Leizhenpeng authored Nov 16, 2023
2 parents 05b12e3 + 3b5fb83 commit 1847020
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 21 deletions.
18 changes: 17 additions & 1 deletion code/handlers/card_pic_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ func NewPicResolutionHandler(cardMsg CardMsg, m MessageHandler) CardHandlerFunc
CommonProcessPicResolution(cardMsg, cardAction, m.sessionCache)
return nil, nil
}
if cardMsg.Kind == PicStyleKind {
CommonProcessPicStyle(cardMsg, cardAction, m.sessionCache)
return nil, nil
}
return nil, ErrNextHandler
}
}
Expand Down Expand Up @@ -57,13 +61,25 @@ func CommonProcessPicResolution(msg CardMsg,
&msg.MsgId)
}

func CommonProcessPicStyle(msg CardMsg,
cardAction *larkcard.CardAction,
cache services.SessionServiceCacheInterface) {
option := cardAction.Action.Option
fmt.Println(larkcore.Prettify(msg))
cache.SetPicStyle(msg.SessionId, services.PicStyle(option))
//send text
replyMsg(context.Background(), "已更新图片风格为"+option,
&msg.MsgId)
}

func (m MessageHandler) CommonProcessPicMore(msg CardMsg) {
resolution := m.sessionCache.GetPicResolution(msg.SessionId)
style := m.sessionCache.GetPicStyle(msg.SessionId)

logger.Debugf("resolution: %v", resolution)
logger.Debug("msg: %v", msg)
question := msg.Value.(string)
bs64, _ := m.gpt.GenerateOneImage(question, resolution)
bs64, _ := m.gpt.GenerateOneImage(question, resolution, style)
replayImageCardByBase64(context.Background(), bs64, &msg.MsgId,
&msg.SessionId, question)
}
Expand Down
6 changes: 4 additions & 2 deletions code/handlers/event_pic_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (*PicAction) Execute(a *ActionInfo) bool {
a.handler.sessionCache.SetMode(*a.info.sessionId,
services.ModePicCreate)
a.handler.sessionCache.SetPicResolution(*a.info.sessionId,
services.Resolution256)
services.Resolution1024)
sendPicCreateInstructionCard(*a.ctx, a.info.sessionId,
a.info.msgId)
return false
Expand Down Expand Up @@ -92,8 +92,10 @@ func (*PicAction) Execute(a *ActionInfo) bool {
if mode == services.ModePicCreate {
resolution := a.handler.sessionCache.GetPicResolution(*a.
info.sessionId)
style := a.handler.sessionCache.GetPicStyle(*a.
info.sessionId)
bs64, err := a.handler.gpt.GenerateOneImage(a.info.qParsed,
resolution)
resolution, style)
if err != nil {
replyMsg(*a.ctx, fmt.Sprintf(
"🤖️:图片生成失败,请稍后再试~\n错误信息: %v", err), a.info.msgId)
Expand Down
44 changes: 36 additions & 8 deletions code/handlers/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
ClearCardKind = CardKind("clear") // 清空上下文
PicModeChangeKind = CardKind("pic_mode_change") // 切换图片创作模式
PicResolutionKind = CardKind("pic_resolution") // 图片分辨率调整
PicStyleKind = CardKind("pic_style") // 图片风格调整
PicTextMoreKind = CardKind("pic_text_more") // 重新根据文本生成图片
PicVarMoreKind = CardKind("pic_var_more") // 变量图片
RoleTagsChooseKind = CardKind("role_tags_choose") // 内置角色所属标签选择
Expand Down Expand Up @@ -324,29 +325,56 @@ func withOneBtn(btn *larkcard.MessageCardEmbedButton) larkcard.

func withPicResolutionBtn(sessionID *string) larkcard.
MessageCardElement {
cancelMenu := newMenu("默认分辨率",
resolutionMenu := newMenu("默认分辨率",
map[string]interface{}{
"value": "0",
"kind": PicResolutionKind,
"sessionId": *sessionID,
"msgId": *sessionID,
},
// dall-e-2 256, 512, 1024
//MenuOption{
// label: "256x256",
// value: string(services.Resolution256),
//},
//MenuOption{
// label: "512x512",
// value: string(services.Resolution512),
//},
// dall-e-3
MenuOption{
label: "256x256",
value: string(services.Resolution256),
label: "1024x1024",
value: string(services.Resolution1024),
},
MenuOption{
label: "512x512",
value: string(services.Resolution512),
label: "1024x1792",
value: string(services.Resolution10241792),
},
MenuOption{
label: "1024x1024",
value: string(services.Resolution1024),
label: "1792x1024",
value: string(services.Resolution17921024),
},
)

styleMenu := newMenu("风格",
map[string]interface{}{
"value": "0",
"kind": PicStyleKind,
"sessionId": *sessionID,
"msgId": *sessionID,
},
MenuOption{
label: "生动风格",
value: string(services.PicStyleVivid),
},
MenuOption{
label: "自然风格",
value: string(services.PicStyleNatural),
},
)

actions := larkcard.NewMessageCardAction().
Actions([]larkcard.MessageCardActionElement{cancelMenu}).
Actions([]larkcard.MessageCardActionElement{resolutionMenu, styleMenu}).
Layout(larkcard.MessageCardActionLayoutFlow.Ptr()).
Build()
return actions
Expand Down
2 changes: 1 addition & 1 deletion code/services/openai/gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestGenerateOneImage(t *testing.T) {
gpt := NewChatGPT(*config)
prompt := "a red apple"
size := "256x256"
imageURL, err := gpt.GenerateOneImage(prompt, size)
imageURL, err := gpt.GenerateOneImage(prompt, size, "")
if err != nil {
t.Errorf("TestGenerateOneImage failed with error: %v", err)
}
Expand Down
13 changes: 9 additions & 4 deletions code/services/openai/picture.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type ImageGenerationRequestBody struct {
N int `json:"n"`
Size string `json:"size"`
ResponseFormat string `json:"response_format"`
Model string `json:"model,omitempty"`
Style string `json:"style,omitempty"`
}

type ImageResponseBody struct {
Expand All @@ -33,12 +35,14 @@ type ImageVariantRequestBody struct {
}

func (gpt *ChatGPT) GenerateImage(prompt string, size string,
n int) ([]string, error) {
n int, style string) ([]string, error) {
requestBody := ImageGenerationRequestBody{
Prompt: prompt,
N: n,
Size: size,
ResponseFormat: "b64_json",
Model: "dall-e-3",
Style: style,
}

imageResponseBody := &ImageResponseBody{}
Expand All @@ -57,8 +61,8 @@ func (gpt *ChatGPT) GenerateImage(prompt string, size string,
}

func (gpt *ChatGPT) GenerateOneImage(prompt string,
size string) (string, error) {
b64s, err := gpt.GenerateImage(prompt, size, 1)
size string, style string) (string, error) {
b64s, err := gpt.GenerateImage(prompt, size, 1, style)
if err != nil {
return "", err
}
Expand All @@ -67,7 +71,8 @@ func (gpt *ChatGPT) GenerateOneImage(prompt string,

func (gpt *ChatGPT) GenerateOneImageWithDefaultSize(
prompt string) (string, error) {
return gpt.GenerateOneImage(prompt, "512x512")
// works for dall-e 2&3
return gpt.GenerateOneImage(prompt, "1024x1024", "")
}

func (gpt *ChatGPT) GenerateImageVariation(images string,
Expand Down
49 changes: 44 additions & 5 deletions code/services/sessionCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ type SessionService struct {
}
type PicSetting struct {
resolution Resolution
style PicStyle
}
type Resolution string
type PicStyle string

type SessionMeta struct {
Mode SessionMode `json:"mode"`
Expand All @@ -24,9 +26,15 @@ type SessionMeta struct {
}

const (
Resolution256 Resolution = "256x256"
Resolution512 Resolution = "512x512"
Resolution1024 Resolution = "1024x1024"
Resolution256 Resolution = "256x256"
Resolution512 Resolution = "512x512"
Resolution1024 Resolution = "1024x1024"
Resolution10241792 Resolution = "1024x1792"
Resolution17921024 Resolution = "1792x1024"
)
const (
PicStyleVivid PicStyle = "vivid"
PicStyleNatural PicStyle = "natural"
)
const (
ModePicCreate SessionMode = "pic_create"
Expand All @@ -44,7 +52,9 @@ type SessionServiceCacheInterface interface {
GetAIMode(sessionId string) openai.AIMode
SetAIMode(sessionId string, aiMode openai.AIMode)
SetPicResolution(sessionId string, resolution Resolution)
SetPicStyle(sessionId string, resolution PicStyle)
GetPicResolution(sessionId string) string
GetPicStyle(sessionId string) string
Clear(sessionId string)
}

Expand Down Expand Up @@ -141,16 +151,45 @@ func (s *SessionService) SetMsg(sessionId string, msg []openai.Messages) {
s.cache.Set(sessionId, sessionMeta, maxCacheTime)
}

func (s *SessionService) SetPicStyle(sessionId string, style PicStyle) {
maxCacheTime := time.Hour * 12

switch style {
case PicStyleVivid, PicStyleNatural:
default:
style = PicStyleVivid
}

sessionContext, ok := s.cache.Get(sessionId)
if !ok {
sessionMeta := &SessionMeta{PicSetting: PicSetting{style: style}}
s.cache.Set(sessionId, sessionMeta, maxCacheTime)
return
}
sessionMeta := sessionContext.(*SessionMeta)
sessionMeta.PicSetting.style = style
s.cache.Set(sessionId, sessionMeta, maxCacheTime)
}

func (s *SessionService) GetPicStyle(sessionId string) string {
sessionContext, ok := s.cache.Get(sessionId)
if !ok {
return string(PicStyleVivid)
}
sessionMeta := sessionContext.(*SessionMeta)
return string(sessionMeta.PicSetting.style)
}

func (s *SessionService) SetPicResolution(sessionId string,
resolution Resolution) {
maxCacheTime := time.Hour * 12

//if not in [Resolution256, Resolution512, Resolution1024] then set
//to Resolution256
switch resolution {
case Resolution256, Resolution512, Resolution1024:
case Resolution256, Resolution512, Resolution1024, Resolution10241792, Resolution17921024:
default:
resolution = Resolution256
resolution = Resolution1024
}

sessionContext, ok := s.cache.Get(sessionId)
Expand Down

0 comments on commit 1847020

Please sign in to comment.