Browse Source

Check vector's embedding provider

master
Yang Luo 2 years ago
parent
commit
b7acf39c17
3 changed files with 21 additions and 17 deletions
  1. +3
    -3
      controllers/message.go
  2. +12
    -12
      controllers/message_util.go
  3. +6
    -2
      object/vector_embedding.go

+ 3
- 3
controllers/message.go View File

@@ -114,13 +114,13 @@ func (c *ApiController) GetMessageAnswer() {
return return
} }


modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2)
_, modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2)
if err != nil { if err != nil {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return
} }


embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2)
embeddingProvider, embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2)
if err != nil { if err != nil {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return
@@ -132,7 +132,7 @@ func (c *ApiController) GetMessageAnswer() {


question := questionMessage.Text question := questionMessage.Text


knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question)
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProvider, embeddingProviderObj, chat.Owner, question)
if err != nil && err.Error() != "no knowledge vectors found" { if err != nil && err.Error() != "no knowledge vectors found" {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return


+ 12
- 12
controllers/message_util.go View File

@@ -32,14 +32,14 @@ func (c *ApiController) ResponseErrorStream(errorText string) {
} }
} }


func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) {
func getModelProviderFromContext(owner string, name string) (*object.Provider, model.ModelProvider, error) {
var providerName string var providerName string
if name != "" { if name != "" {
providerName = name providerName = name
} else { } else {
store, err := object.GetDefaultStore(owner) store, err := object.GetDefaultStore(owner)
if err != nil { if err != nil {
return nil, err
return nil, nil, err
} }


if store != nil && store.ModelProvider != "" { if store != nil && store.ModelProvider != "" {
@@ -57,28 +57,28 @@ func getModelProviderFromContext(owner string, name string) (model.ModelProvider
} }


if provider == nil && err == nil { if provider == nil && err == nil {
return nil, fmt.Errorf("The model provider: %s is not found", providerName)
return nil, nil, fmt.Errorf("The model provider: %s is not found", providerName)
} }
if provider.Category != "Model" || provider.ClientSecret == "" { if provider.Category != "Model" || provider.ClientSecret == "" {
return nil, fmt.Errorf("The model provider: %s is invalid", providerName)
return nil, nil, fmt.Errorf("The model provider: %s is invalid", providerName)
} }


providerObj, err := provider.GetModelProvider() providerObj, err := provider.GetModelProvider()
if err != nil { if err != nil {
return nil, err
return nil, nil, err
} }


return providerObj, err
return provider, providerObj, err
} }


func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) {
func getEmbeddingProviderFromContext(owner string, name string) (*object.Provider, embedding.EmbeddingProvider, error) {
var providerName string var providerName string
if name != "" { if name != "" {
providerName = name providerName = name
} else { } else {
store, err := object.GetDefaultStore(owner) store, err := object.GetDefaultStore(owner)
if err != nil { if err != nil {
return nil, err
return nil, nil, err
} }


if store != nil && store.EmbeddingProvider != "" { if store != nil && store.EmbeddingProvider != "" {
@@ -96,16 +96,16 @@ func getEmbeddingProviderFromContext(owner string, name string) (embedding.Embed
} }


if provider == nil && err == nil { if provider == nil && err == nil {
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName)
return nil, nil, fmt.Errorf("The embedding provider: %s is not found", providerName)
} }
if provider.Category != "Embedding" || provider.ClientSecret == "" { if provider.Category != "Embedding" || provider.ClientSecret == "" {
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName)
return nil, nil, fmt.Errorf("The embedding provider: %s is invalid", providerName)
} }


providerObj, err := provider.GetEmbeddingProvider() providerObj, err := provider.GetEmbeddingProvider()
if err != nil { if err != nil {
return nil, err
return nil, nil, err
} }


return providerObj, err
return provider, providerObj, err
} }

+ 6
- 2
object/vector_embedding.go View File

@@ -161,8 +161,8 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string)
} }
} }


func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
qVector, err := queryVectorSafe(embeddingProvider, text)
func GetNearestKnowledge(embeddingProvider *Provider, embeddingProviderObj embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
qVector, err := queryVectorSafe(embeddingProviderObj, text)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@@ -183,6 +183,10 @@ func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner st
vectorScores := []VectorScore{} vectorScores := []VectorScore{}
texts := []string{} texts := []string{}
for _, vector := range vectors { for _, vector := range vectors {
if embeddingProvider.Name != vector.Provider {
return "", nil, fmt.Errorf("The store's embedding provider: [%s] should equal to vector's embedding provider: [%s], vector = %v", embeddingProvider.Name, vector.Provider, vector)
}

vectorScores = append(vectorScores, VectorScore{ vectorScores = append(vectorScores, VectorScore{
Vector: vector.Name, Vector: vector.Name,
Score: vector.Score, Score: vector.Score,


Loading…
Cancel
Save