Skip to content

Commit

Permalink
Add user field support in conversation settings and enhance user ID r…
Browse files Browse the repository at this point in the history
…etrieval in Xun

- Introduced a new UserField in the Setting struct to allow customization of the user ID field name.
- Implemented getUserID method in Xun to retrieve user IDs based on the configured UserField, improving flexibility in user identification.
- Updated multiple methods (UpdateChatTitle, GetChats, GetHistory, SaveHistory, GetRequest, SaveRequest, GetChat) to utilize the new user ID retrieval logic, ensuring consistent handling of user sessions.
- Enhanced error handling for user ID retrieval, ensuring robust feedback in case of issues.
  • Loading branch information
trheyi committed Dec 17, 2024
1 parent b3bc843 commit f666070
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 13 deletions.
1 change: 1 addition & 0 deletions neo/conversation/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package conversation
// Setting the conversation config
type Setting struct {
Connector string `json:"connector,omitempty"`
UserField string `json:"user_field,omitempty"` // the user id field name, default is user_id
Table string `json:"table,omitempty"`
MaxSize int `json:"max_size,omitempty" yaml:"max_size,omitempty"`
TTL int `json:"ttl,omitempty" yaml:"ttl,omitempty"`
Expand Down
87 changes: 74 additions & 13 deletions neo/conversation/xun.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/yaoapp/gou/connector"
"github.com/yaoapp/gou/session"
"github.com/yaoapp/kun/log"
"github.com/yaoapp/xun/capsule"
"github.com/yaoapp/xun/dbal/query"
Expand All @@ -21,8 +23,7 @@ type Xun struct {

type row struct {
Role string `json:"role"`
Title string `json:"title"` // Chat title
Name string `json:"name"` // User name
Name string `json:"name"` // User name
Content string `json:"content"`
Sid string `json:"sid"`
Rid string `json:"rid"`
Expand Down Expand Up @@ -196,6 +197,24 @@ func (conv *Xun) initChatTable() error {
return nil
}

func (conv *Xun) getUserID(sid string) (string, error) {
field := "user_id"
if conv.setting.UserField != "" {
field = conv.setting.UserField
}

id, err := session.Global().ID(sid).Get(field)
if err != nil {
return "", err
}

if id == nil || id == "" {
return sid, nil
}

return fmt.Sprintf("%v", id), nil
}

func (conv *Xun) getHistoryTable() string {
return conv.setting.Table
}
Expand All @@ -206,8 +225,13 @@ func (conv *Xun) getChatTable() string {

// UpdateChatTitle update the chat title
func (conv *Xun) UpdateChatTitle(sid string, cid string, title string) error {
_, err := conv.newQueryChat().
Where("sid", sid).
userID, err := conv.getUserID(sid)
if err != nil {
return err
}

_, err = conv.newQueryChat().
Where("sid", userID).
Where("chat_id", cid).
Update(map[string]interface{}{
"title": title,
Expand All @@ -218,9 +242,14 @@ func (conv *Xun) UpdateChatTitle(sid string, cid string, title string) error {

// GetChats get the chat list
func (conv *Xun) GetChats(sid string, keywords ...string) ([]map[string]interface{}, error) {
userID, err := conv.getUserID(sid)
if err != nil {
return nil, err
}

qb := conv.newQueryChat().
Select("chat_id", "title").
Where("sid", sid)
Where("sid", userID)

// Add title search if keywords provided
if len(keywords) > 0 && keywords[0] != "" {
Expand Down Expand Up @@ -248,10 +277,14 @@ func (conv *Xun) GetChats(sid string, keywords ...string) ([]map[string]interfac

// GetHistory get the history
func (conv *Xun) GetHistory(sid string, cid string) ([]map[string]interface{}, error) {
userID, err := conv.getUserID(sid)
if err != nil {
return nil, err
}

qb := conv.newQuery().
Select("role", "name", "content").
Where("sid", sid).
Where("sid", userID).
Where("cid", cid).
OrderBy("id", "desc")

Expand Down Expand Up @@ -283,10 +316,20 @@ func (conv *Xun) GetHistory(sid string, cid string) ([]map[string]interface{}, e

// SaveHistory save the history
func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid string) error {

if cid == "" {
cid = uuid.New().String() // Generate a new UUID if cid is empty
}

userID, err := conv.getUserID(sid)
if err != nil {
return err
}

// First ensure chat record exists
exists, err := conv.newQueryChat().
Where("chat_id", cid).
Where("sid", sid).
Where("sid", userID).
Exists()

if err != nil {
Expand All @@ -298,7 +341,7 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid
err = conv.newQueryChat().
Insert(map[string]interface{}{
"chat_id": cid,
"sid": sid,
"sid": userID,
"created_at": time.Now(),
})

Expand All @@ -320,7 +363,7 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid
Role: message["role"].(string),
Name: "",
Content: message["content"].(string),
Sid: sid,
Sid: userID,
Cid: cid,
ExpiredAt: expiredAt,
}
Expand All @@ -331,16 +374,25 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid
values = append(values, value)
}

return conv.newQuery().Insert(values)
err = conv.newQuery().Insert(values)
if err != nil {
return err
}

return nil
}

// GetRequest get the request history
func (conv *Xun) GetRequest(sid string, rid string) ([]map[string]interface{}, error) {
userID, err := conv.getUserID(sid)
if err != nil {
return nil, err
}

qb := conv.newQuery().
Select("role", "name", "content", "sid").
Where("rid", rid).
Where("sid", sid).
Where("sid", userID).
OrderBy("id", "desc")

if conv.setting.TTL > 0 {
Expand Down Expand Up @@ -371,6 +423,10 @@ func (conv *Xun) GetRequest(sid string, rid string) ([]map[string]interface{}, e

// SaveRequest save the request history
func (conv *Xun) SaveRequest(sid string, rid string, cid string, messages []map[string]interface{}) error {
userID, err := conv.getUserID(sid)
if err != nil {
return err
}

defer conv.clean()
var expiredAt interface{} = nil
Expand All @@ -384,7 +440,7 @@ func (conv *Xun) SaveRequest(sid string, rid string, cid string, messages []map[
Role: message["role"].(string),
Name: "",
Content: message["content"].(string),
Sid: sid,
Sid: userID,
Cid: cid,
Rid: rid,
ExpiredAt: expiredAt,
Expand All @@ -401,10 +457,15 @@ func (conv *Xun) SaveRequest(sid string, rid string, cid string, messages []map[

// GetChat get the chat info and its history
func (conv *Xun) GetChat(sid string, cid string) (*ChatInfo, error) {
userID, err := conv.getUserID(sid)
if err != nil {
return nil, err
}

// Get chat info
qb := conv.newQueryChat().
Select("chat_id", "title").
Where("sid", sid).
Where("sid", userID).
Where("chat_id", cid)

row, err := qb.First()
Expand Down

0 comments on commit f666070

Please sign in to comment.