From abae92b96334af01468853595b05c9417772d819 Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Sun, 1 Oct 2023 10:05:52 +0800 Subject: [PATCH] Add size to vector --- object/store.go | 16 +++++++++++++- object/vector.go | 1 + object/vector_embedding.go | 15 +++++++++---- object/vector_test.go | 44 ++++++++++++++++++++++++++++++++++++++ web/src/VectorEditPage.js | 10 +++++++++ web/src/VectorListPage.js | 7 ++++++ 6 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 object/vector_test.go diff --git a/object/store.go b/object/store.go index a63c5c0..180c4c1 100644 --- a/object/store.go +++ b/object/store.go @@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) { } } +func (store *Store) GetModelProvider() (*Provider, error) { + if store.ModelProvider == "" { + return GetDefaultModelProvider() + } + + providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) + return GetProvider(providerId) +} + func (store *Store) GetEmbeddingProvider() (*Provider, error) { if store.EmbeddingProvider == "" { return GetDefaultEmbeddingProvider() @@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) { return false, err } + modelProvider, err := store.GetModelProvider() + if err != nil { + return false, err + } + embeddingProvider, err := store.GetEmbeddingProvider() if err != nil { return false, err @@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) { return false, err } - ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name) + ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType) return ok, err } diff --git a/object/vector.go b/object/vector.go index 38ca6ae..e154728 100644 --- a/object/vector.go +++ b/object/vector.go @@ -32,6 +32,7 @@ type Vector struct { File string `xorm:"varchar(100)" json:"file"` Index int `json:"index"` Text string `xorm:"mediumtext" json:"text"` + Size int `json:"size"` Score float32 `json:"score"` Data []float32 `xorm:"mediumtext" json:"data"` diff --git a/object/vector_embedding.go b/object/vector_embedding.go index f07db19..aa3c1d2 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -22,6 +22,7 @@ import ( "time" "github.com/casbin/casibase/embedding" + "github.com/casbin/casibase/model" "github.com/casbin/casibase/storage" "github.com/casbin/casibase/txt" "github.com/casbin/casibase/util" @@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object { return res } -func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) { +func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) { data, err := queryVectorSafe(embeddingProviderObj, text) if err != nil { return false, err @@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st displayName = text[:25] } + size, err := model.GetTokenSize(modelSubType, text) + if err != nil { + return false, err + } + vector := &Vector{ Owner: "admin", Name: fmt.Sprintf("vector_%s", util.GetRandomName()), @@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st File: fileName, Index: index, Text: text, + Size: size, Data: data, Dimension: len(data), } return AddVector(vector) } -func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) { +func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) { var affected bool files, err := storageProviderObj.ListObjects(prefix) @@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro if timeLimiter.Allow() { fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) } else { err = timeLimiter.Wait(context.Background()) if err != nil { @@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro } fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) } } } diff --git a/object/vector_test.go b/object/vector_test.go new file mode 100644 index 0000000..606216e --- /dev/null +++ b/object/vector_test.go @@ -0,0 +1,44 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package object + +import ( + "testing" + + "github.com/casbin/casibase/model" +) + +func TestUpdateVectors(t *testing.T) { + InitConfig() + + vectors, err := GetGlobalVectors() + if err != nil { + panic(err) + } + + for _, vector := range vectors { + if vector.Text != "" && vector.Size == 0 { + vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text) + if err != nil { + panic(err) + } + + _, err = UpdateVector(vector.GetId(), vector) + if err != nil { + panic(err) + } + } + } +} diff --git a/web/src/VectorEditPage.js b/web/src/VectorEditPage.js index 79ed029..33d587c 100644 --- a/web/src/VectorEditPage.js +++ b/web/src/VectorEditPage.js @@ -132,6 +132,16 @@ class VectorEditPage extends React.Component { }} /> + + + {i18next.t("vector:Size")}: + + + { + this.updateVectorField("size", value); + }} /> + + {i18next.t("vector:Dimension")}: diff --git a/web/src/VectorListPage.js b/web/src/VectorListPage.js index e615095..f2fd996 100644 --- a/web/src/VectorListPage.js +++ b/web/src/VectorListPage.js @@ -176,6 +176,13 @@ class VectorListPage extends React.Component { ); }, }, + { + title: i18next.t("vector:Size"), + dataIndex: "size", + key: "size", + width: "80px", + sorter: (a, b) => a.size - b.size, + }, { title: i18next.t("vector:Data"), dataIndex: "data",