diff --git a/controllers/message.go b/controllers/message.go index 56c2808..b492f08 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -68,46 +68,6 @@ func (c *ApiController) GetMessage() { c.ResponseOk(message) } -func (c *ApiController) ResponseErrorStream(errorText string) { - event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) - _, err := c.Ctx.ResponseWriter.Write([]byte(event)) - if err != nil { - c.ResponseError(err.Error()) - return - } -} - -func getModelProviderFromContext(owner string, name string) (*object.Provider, error) { - var providerName string - if name != "" { - providerName = name - } else { - store, err := object.GetDefaultStore(owner) - if err != nil { - return nil, err - } - - if store != nil && store.ModelProvider != "" { - providerName = store.ModelProvider - } - } - - var provider *object.Provider - var err error - if providerName != "" { - providerId := util.GetIdFromOwnerAndName(owner, providerName) - provider, err = object.GetProvider(providerId) - } else { - provider, err = object.GetDefaultModelProvider() - } - - if provider == nil && err == nil { - return nil, fmt.Errorf("The provider: %s is not found", providerName) - } else { - return provider, err - } -} - func (c *ApiController) GetMessageAnswer() { id := c.Input().Get("id") @@ -154,13 +114,15 @@ func (c *ApiController) GetMessageAnswer() { return } - provider, err := getModelProviderFromContext(chat.Owner, chat.User2) + modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2) if err != nil { c.ResponseError(err.Error()) return } - if provider.Category != "Model" || provider.ClientSecret == "" { - c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId())) + + embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2) + if err != nil { + c.ResponseError(err.Error()) return } @@ -168,11 +130,10 @@ func (c *ApiController) GetMessageAnswer() { c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") - authToken := provider.ClientSecret question := questionMessage.Text var stringBuilder strings.Builder - nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question) + nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) if err != nil && err.Error() != "no knowledge vectors found" { c.ResponseErrorStream(err.Error()) return @@ -184,13 +145,7 @@ func (c *ApiController) GetMessageAnswer() { fmt.Printf("Context: [%s]\n", nearestText) fmt.Printf("Answer: [") - modelProvider, err := provider.GetModelProvider() - if err != nil { - c.ResponseErrorStream(err.Error()) - return - } - - err = modelProvider.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) + err = modelProviderObj.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) if err != nil { c.ResponseErrorStream(err.Error()) return diff --git a/controllers/message_util.go b/controllers/message_util.go new file mode 100644 index 0000000..f8d02c1 --- /dev/null +++ b/controllers/message_util.go @@ -0,0 +1,111 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package controllers + +import ( + "fmt" + + "github.com/casbin/casibase/embedding" + "github.com/casbin/casibase/model" + "github.com/casbin/casibase/object" + "github.com/casbin/casibase/util" +) + +func (c *ApiController) ResponseErrorStream(errorText string) { + event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) + _, err := c.Ctx.ResponseWriter.Write([]byte(event)) + if err != nil { + c.ResponseError(err.Error()) + return + } +} + +func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) { + var providerName string + if name != "" { + providerName = name + } else { + store, err := object.GetDefaultStore(owner) + if err != nil { + return nil, err + } + + if store != nil && store.ModelProvider != "" { + providerName = store.ModelProvider + } + } + + var provider *object.Provider + var err error + if providerName != "" { + providerId := util.GetIdFromOwnerAndName(owner, providerName) + provider, err = object.GetProvider(providerId) + } else { + provider, err = object.GetDefaultModelProvider() + } + + if provider == nil && err == nil { + return nil, fmt.Errorf("The model provider: %s is not found", providerName) + } + if provider.Category != "Model" || provider.ClientSecret == "" { + return nil, fmt.Errorf("The model provider: %s is invalid", providerName) + } + + providerObj, err := provider.GetModelProvider() + if err != nil { + return nil, err + } + + return providerObj, err +} + +func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) { + var providerName string + if name != "" { + providerName = name + } else { + store, err := object.GetDefaultStore(owner) + if err != nil { + return nil, err + } + + if store != nil && store.EmbeddingProvider != "" { + providerName = store.EmbeddingProvider + } + } + + var provider *object.Provider + var err error + if providerName != "" { + providerId := util.GetIdFromOwnerAndName(owner, providerName) + provider, err = object.GetProvider(providerId) + } else { + provider, err = object.GetDefaultEmbeddingProvider() + } + + if provider == nil && err == nil { + return nil, fmt.Errorf("The embedding provider: %s is not found", providerName) + } + if provider.Category != "Embedding" || provider.ClientSecret == "" { + return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) + } + + providerObj, err := provider.GetEmbeddingProvider() + if err != nil { + return nil, err + } + + return providerObj, err +} diff --git a/model/openai_proxy.go b/embedding/openai.go similarity index 52% rename from model/openai_proxy.go rename to embedding/openai.go index 645fed6..de70991 100644 --- a/model/openai_proxy.go +++ b/embedding/openai.go @@ -12,13 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package model +package embedding import ( + "context" + "time" + "github.com/casbin/casibase/proxy" + "github.com/casbin/casibase/util" "github.com/sashabaranov/go-openai" ) +type OpenAiEmbeddingProvider struct { + subType string + secretKey string +} + +func NewOpenAiEmbeddingProvider(subType string, secretKey string) (*OpenAiEmbeddingProvider, error) { + return &OpenAiEmbeddingProvider{subType: subType, secretKey: secretKey}, nil +} + func getProxyClientFromToken(authToken string) *openai.Client { config := openai.DefaultConfig(authToken) config.HTTPClient = proxy.ProxyHttpClient @@ -26,3 +39,20 @@ func getProxyClientFromToken(authToken string) *openai.Client { c := openai.NewClientWithConfig(config) return c } + +func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) { + client := getProxyClientFromToken(p.secretKey) + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) + defer cancel() + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: []string{text}, + Model: openai.EmbeddingModel(util.ParseInt(p.subType)), + }) + if err != nil { + return nil, err + } + + return resp.Data[0].Embedding, nil +} diff --git a/model/embedding.go b/embedding/provider.go similarity index 79% rename from model/embedding.go rename to embedding/provider.go index 51e9787..f643b71 100644 --- a/model/embedding.go +++ b/embedding/provider.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package model +package embedding import ( "context" @@ -22,6 +22,23 @@ import ( "github.com/sashabaranov/go-openai" ) +type EmbeddingProvider interface { + QueryVector(text string, timeout int) ([]float32, error) +} + +func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { + var p EmbeddingProvider + var err error + if typ == "OpenAI" { + p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) + } + + if err != nil { + return nil, err + } + return p, nil +} + func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { client := getProxyClientFromToken(authToken) diff --git a/model/openai.go b/model/openai.go index bf4e09e..5414826 100644 --- a/model/openai.go +++ b/model/openai.go @@ -21,6 +21,7 @@ import ( "net/http" "strings" + "github.com/casbin/casibase/proxy" "github.com/sashabaranov/go-openai" ) @@ -33,6 +34,14 @@ func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvi return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil } +func getProxyClientFromToken(authToken string) *openai.Client { + config := openai.DefaultConfig(authToken) + config.HTTPClient = proxy.ProxyHttpClient + + c := openai.NewClientWithConfig(config) + return c +} + func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { client := getProxyClientFromToken(p.secretKey) diff --git a/model/model.go b/model/provider.go similarity index 100% rename from model/model.go rename to model/provider.go diff --git a/object/provider.go b/object/provider.go index 77cb947..9459f23 100644 --- a/object/provider.go +++ b/object/provider.go @@ -17,6 +17,7 @@ package object import ( "fmt" + "github.com/casbin/casibase/embedding" "github.com/casbin/casibase/model" "github.com/casbin/casibase/util" "xorm.io/core" @@ -116,6 +117,20 @@ func GetDefaultModelProvider() (*Provider, error) { return &provider, nil } +func GetDefaultEmbeddingProvider() (*Provider, error) { + provider := Provider{Owner: "admin", Category: "Embedding"} + existed, err := adapter.engine.Get(&provider) + if err != nil { + return &provider, err + } + + if !existed { + return nil, nil + } + + return &provider, nil +} + func UpdateProvider(id string, provider *Provider) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) p, err := getProvider(owner, name) @@ -173,3 +188,16 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { return pProvider, nil } + +func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { + pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) + if err != nil { + return nil, err + } + + if pProvider == nil { + return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type) + } + + return pProvider, nil +} diff --git a/object/store.go b/object/store.go index 227c029..caf5d53 100644 --- a/object/store.go +++ b/object/store.go @@ -44,8 +44,9 @@ type Store struct { CreatedTime string `xorm:"varchar(100)" json:"createdTime"` DisplayName string `xorm:"varchar(100)" json:"displayName"` - StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` - ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` + StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` + ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` + EmbeddingProvider string `xorm:"varchar(100)" json:"embeddingProvider"` FileTree *File `xorm:"mediumtext" json:"fileTree"` PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"` @@ -150,22 +151,26 @@ func (store *Store) GetId() string { return fmt.Sprintf("%s/%s", store.Owner, store.Name) } -func (store *Store) GetModelProvider() (*Provider, error) { - if store.ModelProvider == "" { - return GetDefaultModelProvider() +func (store *Store) GetEmbeddingProvider() (*Provider, error) { + if store.EmbeddingProvider == "" { + return GetDefaultEmbeddingProvider() } - providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) + providerId := util.GetIdFromOwnerAndName(store.Owner, store.EmbeddingProvider) return GetProvider(providerId) } func RefreshStoreVectors(store *Store) (bool, error) { - provider, err := store.GetModelProvider() + embeddingProvider, err := store.GetEmbeddingProvider() if err != nil { return false, err } - authToken := provider.ClientSecret - ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name) + embeddingProviderObj, err := embeddingProvider.GetEmbeddingProvider() + if err != nil { + return false, err + } + + ok, err := addVectorsForStore(embeddingProviderObj, store.StorageProvider, "", store.Name) return ok, err } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index de018e5..32a2128 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -20,7 +20,7 @@ import ( "path/filepath" "time" - "github.com/casbin/casibase/model" + "github.com/casbin/casibase/embedding" "github.com/casbin/casibase/storage" "github.com/casbin/casibase/txt" "github.com/casbin/casibase/util" @@ -53,8 +53,9 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, return filterTextFiles(files), nil } -func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { - embedding, err := model.GetEmbeddingSafe(authToken, text) +func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) { + data, err := embeddingProviderObj.QueryVector(text, 5) + // data, err := model.GetEmbeddingSafe(authToken, text) if err != nil { return false, err } @@ -72,16 +73,16 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName Store: storeName, File: fileName, Text: text, - Data: embedding, + Data: data, } return AddVector(vector) } -func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) { +func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storageProviderName string, key string, storeName string) (bool, error) { var affected bool var err error - objs, err := getFilteredFileObjects(provider, key) + objs, err := getFilteredFileObjects(storageProviderName, key) if err != nil { return false, err } @@ -99,7 +100,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName for i, textSection := range textSections { if timeLimiter.Allow() { fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) } else { err = timeLimiter.Wait(context.Background()) if err != nil { @@ -107,7 +108,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName } fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) } } } @@ -127,8 +128,9 @@ func getRelatedVectors(owner string) ([]*Vector, error) { return vectors, nil } -func GetNearestVectorText(authToken string, owner string, question string) (string, error) { - qVector, err := model.GetEmbeddingSafe(authToken, question) +func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { + qVector, err := embeddingProvider.QueryVector(text, 5) + // qVector, err := embedding.GetEmbeddingSafe(authToken, question) if err != nil { return "", err } diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 05fb73b..de7a49e 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -102,8 +102,7 @@ class ProviderEditPage extends React.Component { { [ {id: "Model", name: "Model"}, - {id: "Vector Database", name: "Vector Database"}, - {id: "Storage", name: "Storage"}, + {id: "Embedding", name: "Embedding"}, ].map((item, index) => ) } @@ -116,11 +115,9 @@ class ProviderEditPage extends React.Component {