From b7acf39c17d81e083d225fe9a8cbd18e76934e3c Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Sat, 30 Sep 2023 19:06:19 +0800 Subject: [PATCH] Check vector's embedding provider --- controllers/message.go | 6 +++--- controllers/message_util.go | 24 ++++++++++++------------ object/vector_embedding.go | 8 ++++++-- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/controllers/message.go b/controllers/message.go index ab67f40..80670cb 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -114,13 +114,13 @@ func (c *ApiController) GetMessageAnswer() { return } - modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2) + _, modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2) if err != nil { c.ResponseErrorStream(err.Error()) return } - embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2) + embeddingProvider, embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2) if err != nil { c.ResponseErrorStream(err.Error()) return @@ -132,7 +132,7 @@ func (c *ApiController) GetMessageAnswer() { question := questionMessage.Text - knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question) + knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProvider, embeddingProviderObj, chat.Owner, question) if err != nil && err.Error() != "no knowledge vectors found" { c.ResponseErrorStream(err.Error()) return diff --git a/controllers/message_util.go b/controllers/message_util.go index f8d02c1..9cf3aea 100644 --- a/controllers/message_util.go +++ b/controllers/message_util.go @@ -32,14 +32,14 @@ func (c *ApiController) ResponseErrorStream(errorText string) { } } -func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) { +func getModelProviderFromContext(owner string, name string) (*object.Provider, model.ModelProvider, error) { var providerName string if name != "" { providerName = name } else { store, err := object.GetDefaultStore(owner) if err != nil { - return nil, err + return nil, nil, err } if store != nil && store.ModelProvider != "" { @@ -57,28 +57,28 @@ func getModelProviderFromContext(owner string, name string) (model.ModelProvider } if provider == nil && err == nil { - return nil, fmt.Errorf("The model provider: %s is not found", providerName) + return nil, nil, fmt.Errorf("The model provider: %s is not found", providerName) } if provider.Category != "Model" || provider.ClientSecret == "" { - return nil, fmt.Errorf("The model provider: %s is invalid", providerName) + return nil, nil, fmt.Errorf("The model provider: %s is invalid", providerName) } providerObj, err := provider.GetModelProvider() if err != nil { - return nil, err + return nil, nil, err } - return providerObj, err + return provider, providerObj, err } -func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) { +func getEmbeddingProviderFromContext(owner string, name string) (*object.Provider, embedding.EmbeddingProvider, error) { var providerName string if name != "" { providerName = name } else { store, err := object.GetDefaultStore(owner) if err != nil { - return nil, err + return nil, nil, err } if store != nil && store.EmbeddingProvider != "" { @@ -96,16 +96,16 @@ func getEmbeddingProviderFromContext(owner string, name string) (embedding.Embed } if provider == nil && err == nil { - return nil, fmt.Errorf("The embedding provider: %s is not found", providerName) + return nil, nil, fmt.Errorf("The embedding provider: %s is not found", providerName) } if provider.Category != "Embedding" || provider.ClientSecret == "" { - return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) + return nil, nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) } providerObj, err := provider.GetEmbeddingProvider() if err != nil { - return nil, err + return nil, nil, err } - return providerObj, err + return provider, providerObj, err } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index 29e7c9c..7e1f22b 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -161,8 +161,8 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) } } -func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) { - qVector, err := queryVectorSafe(embeddingProvider, text) +func GetNearestKnowledge(embeddingProvider *Provider, embeddingProviderObj embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) { + qVector, err := queryVectorSafe(embeddingProviderObj, text) if err != nil { return "", nil, err } @@ -183,6 +183,10 @@ func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner st vectorScores := []VectorScore{} texts := []string{} for _, vector := range vectors { + if embeddingProvider.Name != vector.Provider { + return "", nil, fmt.Errorf("The store's embedding provider: [%s] should equal to vector's embedding provider: [%s], vector = %v", embeddingProvider.Name, vector.Provider, vector) + } + vectorScores = append(vectorScores, VectorScore{ Vector: vector.Name, Score: vector.Score,