diff --git a/object/store.go b/object/store.go
index a63c5c0..180c4c1 100644
--- a/object/store.go
+++ b/object/store.go
@@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) {
}
}
+// GetModelProvider returns GetDefaultModelProvider() when store.ModelProvider is empty; otherwise it calls GetProvider with the ID built from store.Owner and store.ModelProvider.
+func (store *Store) GetModelProvider() (*Provider, error) {
+ if store.ModelProvider == "" {
+ return GetDefaultModelProvider()
+ }
+ providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider)
+ return GetProvider(providerId)
+}
+
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
if store.EmbeddingProvider == "" {
return GetDefaultEmbeddingProvider()
@@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) {
return false, err
}
+ modelProvider, err := store.GetModelProvider()
+ if err != nil {
+ return false, err
+ }
+
embeddingProvider, err := store.GetEmbeddingProvider()
if err != nil {
return false, err
@@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) {
return false, err
}
- ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name)
+ ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType)
return ok, err
}
diff --git a/object/vector.go b/object/vector.go
index 38ca6ae..e154728 100644
--- a/object/vector.go
+++ b/object/vector.go
@@ -32,6 +32,7 @@ type Vector struct {
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
+ Size int `json:"size"`
Score float32 `json:"score"`
Data []float32 `xorm:"mediumtext" json:"data"`
diff --git a/object/vector_embedding.go b/object/vector_embedding.go
index f07db19..aa3c1d2 100644
--- a/object/vector_embedding.go
+++ b/object/vector_embedding.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/casbin/casibase/embedding"
+ "github.com/casbin/casibase/model"
"github.com/casbin/casibase/storage"
"github.com/casbin/casibase/txt"
"github.com/casbin/casibase/util"
@@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object {
return res
}
-func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) {
+func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) {
data, err := queryVectorSafe(embeddingProviderObj, text)
if err != nil {
return false, err
@@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
displayName = text[:25]
}
+ size, err := model.GetTokenSize(modelSubType, text)
+ if err != nil {
+ return false, err
+ }
+
vector := &Vector{
Owner: "admin",
Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
@@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st
File: fileName,
Index: index,
Text: text,
+ Size: size,
Data: data,
Dimension: len(data),
}
return AddVector(vector)
}
-func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) {
+func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) {
var affected bool
files, err := storageProviderObj.ListObjects(prefix)
@@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
if timeLimiter.Allow() {
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
- affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
+ affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
} else {
err = timeLimiter.Wait(context.Background())
if err != nil {
@@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro
}
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
- affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName)
+ affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType)
}
}
}
diff --git a/object/vector_test.go b/object/vector_test.go
new file mode 100644
index 0000000..606216e
--- /dev/null
+++ b/object/vector_test.go
@@ -0,0 +1,44 @@
+// Copyright 2023 The casbin Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package object
+
+import (
+ "testing"
+
+ "github.com/casbin/casibase/model"
+)
+
+// TestUpdateVectors fills in Size for every vector returned by GetGlobalVectors that has non-empty Text and a zero Size, using model.GetTokenSize with "text-davinci-003", and saves each change through UpdateVector.
+func TestUpdateVectors(t *testing.T) {
+	InitConfig()
+	vectors, err := GetGlobalVectors()
+	if err != nil {
+		t.Fatalf("get global vectors: %v", err)
+	}
+
+	for _, vector := range vectors {
+		if vector.Text != "" && vector.Size == 0 {
+			vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text)
+			if err != nil {
+				t.Fatalf("get token size for vector %v: %v", vector.GetId(), err)
+			}
+
+			_, err = UpdateVector(vector.GetId(), vector)
+			if err != nil {
+				t.Fatalf("update vector %v: %v", vector.GetId(), err)
+			}
+		}
+	}
+}
diff --git a/web/src/VectorEditPage.js b/web/src/VectorEditPage.js
index 79ed029..33d587c 100644
--- a/web/src/VectorEditPage.js
+++ b/web/src/VectorEditPage.js
@@ -132,6 +132,16 @@ class VectorEditPage extends React.Component {
}} />
+
+
+ {i18next.t("vector:Size")}:
+
+
+ {
+ this.updateVectorField("size", value);
+ }} />
+
+
{i18next.t("vector:Dimension")}:
diff --git a/web/src/VectorListPage.js b/web/src/VectorListPage.js
index e615095..f2fd996 100644
--- a/web/src/VectorListPage.js
+++ b/web/src/VectorListPage.js
@@ -176,6 +176,13 @@ class VectorListPage extends React.Component {
);
},
},
+ {
+ title: i18next.t("vector:Size"),
+ dataIndex: "size",
+ key: "size",
+ width: "80px",
+ sorter: (a, b) => a.size - b.size,
+ },
{
title: i18next.t("vector:Data"),
dataIndex: "data",