@@ -172,6 +172,15 @@ func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) { | |||
} | |||
} | |||
func (store *Store) GetModelProvider() (*Provider, error) { | |||
if store.ModelProvider == "" { | |||
return GetDefaultModelProvider() | |||
} | |||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) | |||
return GetProvider(providerId) | |||
} | |||
func (store *Store) GetEmbeddingProvider() (*Provider, error) { | |||
if store.EmbeddingProvider == "" { | |||
return GetDefaultEmbeddingProvider() | |||
@@ -187,6 +196,11 @@ func RefreshStoreVectors(store *Store) (bool, error) { | |||
return false, err | |||
} | |||
modelProvider, err := store.GetModelProvider() | |||
if err != nil { | |||
return false, err | |||
} | |||
embeddingProvider, err := store.GetEmbeddingProvider() | |||
if err != nil { | |||
return false, err | |||
@@ -197,6 +211,6 @@ func RefreshStoreVectors(store *Store) (bool, error) { | |||
return false, err | |||
} | |||
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name) | |||
ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name, embeddingProvider.Name, modelProvider.SubType) | |||
return ok, err | |||
} |
@@ -32,6 +32,7 @@ type Vector struct { | |||
File string `xorm:"varchar(100)" json:"file"` | |||
Index int `json:"index"` | |||
Text string `xorm:"mediumtext" json:"text"` | |||
Size int `json:"size"` | |||
Score float32 `json:"score"` | |||
Data []float32 `xorm:"mediumtext" json:"data"` | |||
@@ -22,6 +22,7 @@ import ( | |||
"time" | |||
"github.com/casbin/casibase/embedding" | |||
"github.com/casbin/casibase/model" | |||
"github.com/casbin/casibase/storage" | |||
"github.com/casbin/casibase/txt" | |||
"github.com/casbin/casibase/util" | |||
@@ -45,7 +46,7 @@ func filterTextFiles(files []*storage.Object) []*storage.Object { | |||
return res | |||
} | |||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string) (bool, error) { | |||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string, index int, embeddingProviderName string, modelSubType string) (bool, error) { | |||
data, err := queryVectorSafe(embeddingProviderObj, text) | |||
if err != nil { | |||
return false, err | |||
@@ -56,6 +57,11 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st | |||
displayName = text[:25] | |||
} | |||
size, err := model.GetTokenSize(modelSubType, text) | |||
if err != nil { | |||
return false, err | |||
} | |||
vector := &Vector{ | |||
Owner: "admin", | |||
Name: fmt.Sprintf("vector_%s", util.GetRandomName()), | |||
@@ -66,13 +72,14 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st | |||
File: fileName, | |||
Index: index, | |||
Text: text, | |||
Size: size, | |||
Data: data, | |||
Dimension: len(data), | |||
} | |||
return AddVector(vector) | |||
} | |||
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string) (bool, error) { | |||
func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string, embeddingProviderName string, modelSubType string) (bool, error) { | |||
var affected bool | |||
files, err := storageProviderObj.ListObjects(prefix) | |||
@@ -106,7 +113,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro | |||
if timeLimiter.Allow() { | |||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) | |||
} else { | |||
err = timeLimiter.Wait(context.Background()) | |||
if err != nil { | |||
@@ -114,7 +121,7 @@ func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingPro | |||
} | |||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key, i, embeddingProviderName, modelSubType) | |||
} | |||
} | |||
} | |||
@@ -0,0 +1,44 @@ | |||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||
// | |||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||
// you may not use this file except in compliance with the License. | |||
// You may obtain a copy of the License at | |||
// | |||
// http://www.apache.org/licenses/LICENSE-2.0 | |||
// | |||
// Unless required by applicable law or agreed to in writing, software | |||
// distributed under the License is distributed on an "AS IS" BASIS, | |||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package object | |||
import ( | |||
"testing" | |||
"github.com/casbin/casibase/model" | |||
) | |||
func TestUpdateVectors(t *testing.T) { | |||
InitConfig() | |||
vectors, err := GetGlobalVectors() | |||
if err != nil { | |||
panic(err) | |||
} | |||
for _, vector := range vectors { | |||
if vector.Text != "" && vector.Size == 0 { | |||
vector.Size, err = model.GetTokenSize("text-davinci-003", vector.Text) | |||
if err != nil { | |||
panic(err) | |||
} | |||
_, err = UpdateVector(vector.GetId(), vector) | |||
if err != nil { | |||
panic(err) | |||
} | |||
} | |||
} | |||
} |
@@ -132,6 +132,16 @@ class VectorEditPage extends React.Component { | |||
}} /> | |||
</Col> | |||
</Row> | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("vector:Size")}: | |||
</Col> | |||
<Col span={22} > | |||
<InputNumber disabled={true} value={this.state.vector.size} onChange={value => { | |||
this.updateVectorField("size", value); | |||
}} /> | |||
</Col> | |||
</Row> | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("vector:Dimension")}: | |||
@@ -176,6 +176,13 @@ class VectorListPage extends React.Component { | |||
); | |||
}, | |||
}, | |||
{ | |||
title: i18next.t("vector:Size"), | |||
dataIndex: "size", | |||
key: "size", | |||
width: "80px", | |||
sorter: (a, b) => a.size - b.size, | |||
}, | |||
{ | |||
title: i18next.t("vector:Data"), | |||
dataIndex: "data", | |||