| @@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) { | |||
| } | |||
| } | |||
| func (store *Store) GetModelProvider() (*Provider, error) { | |||
| if store.ModelProvider == "" { | |||
| return GetDefaultModelProvider() | |||
| } | |||
| providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) | |||
| return GetProvider(providerId) | |||
| } | |||
| func (store *Store) GetEmbeddingProvider() (*Provider, error) { | |||
| if store.EmbeddingProvider == "" { | |||
| return GetDefaultEmbeddingProvider() | |||
| @@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) { | |||
| return false, err | |||
| } | |||
| modelProvider, err := store.GetModelProvider() | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| embeddingProvider, err := store.GetEmbeddingProvider() | |||
| if err != nil { | |||
| return false, err | |||
| @@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) { | |||
| return false, err | |||
| } | |||
| ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name) | |||
| ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType) | |||
| return ok, err | |||
| } | |||
| @@ -32,6 +32,7 @@ type Vector struct { | |||
| File string `xorm:"varchar(100)" json:"file"` | |||
| Index int `json:"index"` | |||
| Text string `xorm:"mediumtext" json:"text"` | |||
| Size int `json:"size"` | |||
| Score float32 `json:"score"` | |||
| Data []float32 `xorm:"mediumtext" json:"data"` | |||
| @@ -22,6 +22,7 @@ import ( | |||
| "time" | |||
| "github.com/casbin/casibase/embedding" | |||
| "github.com/casbin/casibase/model" | |||
| "github.com/casbin/casibase/storage" | |||
| "github.com/casbin/casibase/txt" | |||
| "github.com/casbin/casibase/util" | |||
| @@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object { | |||
| return res | |||
| } | |||
| func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) { | |||
| func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) { | |||
| data, err := queryVectorSafe(embeddingProviderObj, text) | |||
| if err != nil { | |||
| return false, err | |||
| @@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st | |||
| displayName = text[:25] | |||
| } | |||
| size, err := model.GetTokenSize(modelSubType, text) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| vector := &Vector{ | |||
| Owner: "admin", | |||
| Name: fmt.Sprintf("vector_%s", util.GetRandomName()), | |||
| @@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st | |||
| File: fileName, | |||
| Index: index, | |||
| Text: text, | |||
| Size: size, | |||
| Data: data, | |||
| Dimension: len(data), | |||
| } | |||
| return AddVector(vector) | |||
| } | |||
| func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) { | |||
| func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) { | |||
| var affected bool | |||
| files, err := storageProviderObj.ListObjects(prefix) | |||
| @@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro | |||
| if timeLimiter.Allow() { | |||
| fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
| affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) | |||
| affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) | |||
| } else { | |||
| err = timeLimiter.Wait(context.Background()) | |||
| if err != nil { | |||
| @@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro | |||
| } | |||
| fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
| affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) | |||
| affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,44 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package object | |||
| import ( | |||
| "testing" | |||
| "github.com/casbin/casibase/model" | |||
| ) | |||
| func TestUpdateVectors(t *testing.T) { | |||
| InitConfig() | |||
| vectors, err := GetGlobalVectors() | |||
| if err != nil { | |||
| panic(err) | |||
| } | |||
| for _, vector := range vectors { | |||
| if vector.Text != "" && vector.Size == 0 { | |||
| vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text) | |||
| if err != nil { | |||
| panic(err) | |||
| } | |||
| _, err = UpdateVector(vector.GetId(), vector) | |||
| if err != nil { | |||
| panic(err) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -132,6 +132,16 @@ class VectorEditPage extends React.Component { | |||
| }} /> | |||
| </Col> | |||
| </Row> | |||
| <Row style={{marginTop: "20px"}} > | |||
| <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
| {i18next.t("vector:Size")}: | |||
| </Col> | |||
| <Col span={22} > | |||
| <InputNumber disabled={true} value={this.state.vector.size} onChange={value => { | |||
| this.updateVectorField("size", value); | |||
| }} /> | |||
| </Col> | |||
| </Row> | |||
| <Row style={{marginTop: "20px"}} > | |||
| <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
| {i18next.t("vector:Dimension")}: | |||
| @@ -176,6 +176,13 @@ class VectorListPage extends React.Component { | |||
| ); | |||
| }, | |||
| }, | |||
| { | |||
| title: i18next.t("vector:Size"), | |||
| dataIndex: "size", | |||
| key: "size", | |||
| width: "80px", | |||
| sorter: (a, b) => a.size - b.size, | |||
| }, | |||
| { | |||
| title: i18next.t("vector:Data"), | |||
| dataIndex: "data", | |||