diff --git a/controllers/message.go b/controllers/message.go index c766e54..fb77e9a 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -137,7 +137,7 @@ func (c *ApiController) GetMessageAnswer() { return } - if provider.Category != "AI" || provider.ClientSecret == "" { + if provider.Category != "Model" || provider.ClientSecret == "" { c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", providerId)) return } diff --git a/object/chat.go b/object/chat.go index 702cb5f..6553f73 100644 --- a/object/chat.go +++ b/object/chat.go @@ -96,6 +96,17 @@ func UpdateChat(id string, chat *Chat) (bool, error) { } func AddChat(chat *Chat) (bool, error) { + if chat.Type == "AI" && chat.User2 == "" { + provider, err := getDefaultModelProvider() + if err != nil { + return false, err + } + + if provider != nil { + chat.User2 = provider.Name + } + } + affected, err := adapter.engine.Insert(chat) if err != nil { return false, err diff --git a/object/provider.go b/object/provider.go index 25983bc..a8da123 100644 --- a/object/provider.go +++ b/object/provider.go @@ -100,6 +100,20 @@ func GetProvider(id string) (*Provider, error) { return getProvider(owner, name) } +func getDefaultModelProvider() (*Provider, error) { + provider := Provider{Owner: "admin", Category: "Model"} + existed, err := adapter.engine.Get(&provider) + if err != nil { + return &provider, err + } + + if !existed { + return nil, nil + } + + return &provider, nil +} + func UpdateProvider(id string, provider *Provider) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) _, err := getProvider(owner, name)