* feat: add refresh store vectors * fix: clean the object.Url from casdoorsdkHEAD
@@ -0,0 +1,90 @@ | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
package ai | |||||
import ( | |||||
"bufio" | |||||
"context" | |||||
"fmt" | |||||
"io" | |||||
"time" | |||||
"github.com/sashabaranov/go-openai" | |||||
) | |||||
// splitTxt reads f line by line and packs consecutive lines into chunks of at
// most maxLength bytes.
//
// Lines are concatenated without a separator (the scanner strips line
// endings). A single line longer than maxLength becomes a chunk of its own and
// is not split further. Empty chunks are never emitted.
//
// If the scanner stops early (for example on a read error), the chunks
// gathered so far are returned. splitTxt does not close f; that remains the
// caller's responsibility.
func splitTxt(f io.ReadCloser) []string {
	const maxLength = 512 * 3

	scanner := bufio.NewScanner(f)
	var res []string
	var temp string
	for scanner.Scan() {
		line := scanner.Text()
		if len(temp)+len(line) <= maxLength {
			temp += line
			continue
		}

		// The pending chunk is empty when an over-long line arrives before any
		// other content; flushing it would emit a useless "" chunk.
		if len(temp) > 0 {
			res = append(res, temp)
		}
		temp = line
	}

	if len(temp) > 0 {
		res = append(res, temp)
	}

	return res
}
func GetSplitTxt(f io.ReadCloser) []string { | |||||
return splitTxt(f) | |||||
} | |||||
func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) { | |||||
client := getProxyClientFromToken(authToken) | |||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) | |||||
defer cancel() | |||||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||||
Input: input, | |||||
Model: openai.AdaEmbeddingV2, | |||||
}) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return resp.Data[0].Embedding, nil | |||||
} | |||||
func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) { | |||||
var embedding []float32 | |||||
var err error | |||||
for i := 0; i < 10; i++ { | |||||
embedding, err = getEmbedding(authToken, input, i) | |||||
if err != nil { | |||||
if i > 0 { | |||||
fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) | |||||
} | |||||
} else { | |||||
break | |||||
} | |||||
} | |||||
if err != nil { | |||||
return nil, err | |||||
} else { | |||||
return embedding, nil | |||||
} | |||||
} |
@@ -123,3 +123,20 @@ func (c *ApiController) DeleteStore() { | |||||
c.ResponseOk(sucess) | c.ResponseOk(sucess) | ||||
} | } | ||||
func (c *ApiController) RefreshStoreVectors() { | |||||
var store object.Store | |||||
err := json.Unmarshal(c.Ctx.Input.RequestBody, &store) | |||||
if err != nil { | |||||
c.ResponseError(err.Error()) | |||||
return | |||||
} | |||||
success, err := object.RefreshStoreVectors(&store) | |||||
if err != nil { | |||||
c.ResponseError(err.Error()) | |||||
return | |||||
} | |||||
c.ResponseOk(success) | |||||
} |
@@ -148,3 +148,20 @@ func DeleteStore(store *Store) (bool, error) { | |||||
func (store *Store) GetId() string { | func (store *Store) GetId() string { | ||||
return fmt.Sprintf("%s/%s", store.Owner, store.Name) | return fmt.Sprintf("%s/%s", store.Owner, store.Name) | ||||
} | } | ||||
func RefreshStoreVectors(store *Store) (bool, error) { | |||||
provider, err := getDefaultModelProvider() | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
authToken := provider.ClientSecret | |||||
success, err := setTxtObjectVector(authToken, store.StorageProvider, "", store.Name) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
if !success { | |||||
return false, nil | |||||
} | |||||
return true, nil | |||||
} |
@@ -0,0 +1,137 @@ | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
package object | |||||
import ( | |||||
"context" | |||||
"fmt" | |||||
"io" | |||||
"net/http" | |||||
"strings" | |||||
"time" | |||||
"github.com/casbin/casibase/ai" | |||||
"github.com/casbin/casibase/storage" | |||||
"github.com/casbin/casibase/util" | |||||
"golang.org/x/time/rate" | |||||
) | |||||
// isTxt reports whether filename ends with the ".txt" extension.
// The comparison is case-sensitive.
func isTxt(filename string) bool {
	const txtSuffix = ".txt"
	return strings.HasSuffix(filename, txtSuffix)
}
func filterTxtFiles(files []*storage.Object) []*storage.Object { | |||||
var res []*storage.Object | |||||
for _, file := range files { | |||||
if isTxt(file.Key) { | |||||
res = append(res, file) | |||||
} | |||||
} | |||||
return res | |||||
} | |||||
func getTxtFiles(provider string, prefix string) ([]*storage.Object, error) { | |||||
files, err := storage.ListObjects(provider, prefix) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return filterTxtFiles(files), nil | |||||
} | |||||
func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { | |||||
resp, err := http.Get(object.Url) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if resp.StatusCode != http.StatusOK { | |||||
resp.Body.Close() | |||||
return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode) | |||||
} | |||||
return resp.Body, nil | |||||
} | |||||
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | |||||
embedding, err := ai.GetEmbeddingSafe(authToken, []string{text}) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
displayName := text | |||||
if len(text) > 25 { | |||||
displayName = text[:25] | |||||
} | |||||
vector := &Vector{ | |||||
Owner: "admin", | |||||
Name: fmt.Sprintf("vector_%s", util.GetRandomName()), | |||||
CreatedTime: util.GetCurrentTime(), | |||||
DisplayName: displayName, | |||||
Store: storeName, | |||||
File: fileName, | |||||
Text: text, | |||||
Data: embedding, | |||||
} | |||||
return AddVector(vector) | |||||
} | |||||
func setTxtObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { | |||||
lb := rate.NewLimiter(rate.Every(time.Minute), 3) | |||||
txtObjects, err := getTxtFiles(provider, key) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
if len(txtObjects) == 0 { | |||||
return false, nil | |||||
} | |||||
for _, txtObject := range txtObjects { | |||||
readCloser, err := getObjectReadCloser(txtObject) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
defer readCloser.Close() | |||||
splitTxts := ai.GetSplitTxt(readCloser) | |||||
for _, splitTxt := range splitTxts { | |||||
if lb.Allow() { | |||||
success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
if !success { | |||||
return false, nil | |||||
} | |||||
} else { | |||||
err := lb.Wait(context.Background()) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
if !success { | |||||
return false, nil | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return true, nil | |||||
} |
@@ -65,6 +65,7 @@ func initAPI() { | |||||
beego.Router("/api/update-store", &controllers.ApiController{}, "POST:UpdateStore") | beego.Router("/api/update-store", &controllers.ApiController{}, "POST:UpdateStore") | ||||
beego.Router("/api/add-store", &controllers.ApiController{}, "POST:AddStore") | beego.Router("/api/add-store", &controllers.ApiController{}, "POST:AddStore") | ||||
beego.Router("/api/delete-store", &controllers.ApiController{}, "POST:DeleteStore") | beego.Router("/api/delete-store", &controllers.ApiController{}, "POST:DeleteStore") | ||||
beego.Router("/api/refresh-store-vectors", &controllers.ApiController{}, "POST:RefreshStoreVectors") | |||||
beego.Router("/api/get-storage-providers", &controllers.ApiController{}, "GET:GetStorageProviders") | beego.Router("/api/get-storage-providers", &controllers.ApiController{}, "GET:GetStorageProviders") | ||||
@@ -26,6 +26,7 @@ class StoreListPage extends React.Component { | |||||
this.state = { | this.state = { | ||||
classes: props, | classes: props, | ||||
stores: null, | stores: null, | ||||
generating: false, | |||||
}; | }; | ||||
} | } | ||||
@@ -93,6 +94,23 @@ class StoreListPage extends React.Component { | |||||
}); | }); | ||||
} | } | ||||
refreshStoreVectors(i) { | |||||
this.setState({generating: true}); | |||||
StoreBackend.refreshStoreVectors(this.state.stores[i]) | |||||
.then((res) => { | |||||
if (res.status === "ok") { | |||||
Setting.showMessage("success", "Vectors generated successfully"); | |||||
} else { | |||||
Setting.showMessage("error", `Vectors failed to generate: ${res.msg}`); | |||||
} | |||||
this.setState({generating: false}); | |||||
}) | |||||
.catch(error => { | |||||
Setting.showMessage("error", `Vectors failed to generate: ${error}`); | |||||
this.setState({generating: false}); | |||||
}); | |||||
} | |||||
renderTable(stores) { | renderTable(stores) { | ||||
const columns = [ | const columns = [ | ||||
{ | { | ||||
@@ -135,6 +153,7 @@ class StoreListPage extends React.Component { | |||||
{ | { | ||||
!Setting.isLocalAdminUser(this.props.account) ? null : ( | !Setting.isLocalAdminUser(this.props.account) ? null : ( | ||||
<React.Fragment> | <React.Fragment> | ||||
<Button style={{marginBottom: "10px", marginRight: "10px"}} disabled={this.state.generating} onClick={() => this.refreshStoreVectors(index)}>{i18next.t("store:Refresh Vectors")}</Button> | |||||
<Button style={{marginBottom: "10px", marginRight: "10px"}} type="primary" onClick={() => this.props.history.push(`/stores/${record.owner}/${record.name}`)}>{i18next.t("general:Edit")}</Button> | <Button style={{marginBottom: "10px", marginRight: "10px"}} type="primary" onClick={() => this.props.history.push(`/stores/${record.owner}/${record.name}`)}>{i18next.t("general:Edit")}</Button> | ||||
<Popconfirm | <Popconfirm | ||||
title={`Sure to delete store: ${record.name} ?`} | title={`Sure to delete store: ${record.name} ?`} | ||||
@@ -145,6 +145,13 @@ class VectorListPage extends React.Component { | |||||
key: "text", | key: "text", | ||||
width: "200px", | width: "200px", | ||||
sorter: (a, b) => a.text.localeCompare(b.text), | sorter: (a, b) => a.text.localeCompare(b.text), | ||||
render: (text, record, index) => { | |||||
return ( | |||||
<div style={{maxWidth: "200px"}}> | |||||
{Setting.getShortText(text)} | |||||
</div> | |||||
); | |||||
}, | |||||
}, | }, | ||||
{ | { | ||||
title: i18next.t("vector:Data"), | title: i18next.t("vector:Data"), | ||||
@@ -153,7 +160,11 @@ class VectorListPage extends React.Component { | |||||
width: "200px", | width: "200px", | ||||
sorter: (a, b) => a.data.localeCompare(b.data), | sorter: (a, b) => a.data.localeCompare(b.data), | ||||
render: (text, record, index) => { | render: (text, record, index) => { | ||||
return JSON.stringify(text); | |||||
return ( | |||||
<div style={{maxWidth: "200px"}}> | |||||
{Setting.getShortText(JSON.stringify(text))} | |||||
</div> | |||||
); | |||||
}, | }, | ||||
}, | }, | ||||
{ | { | ||||
@@ -1,17 +1,17 @@ | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
import * as Setting from "../Setting"; | import * as Setting from "../Setting"; | ||||
export function getGlobalStores() { | export function getGlobalStores() { | ||||
@@ -61,3 +61,12 @@ export function deleteStore(store) { | |||||
body: JSON.stringify(newStore), | body: JSON.stringify(newStore), | ||||
}).then(res => res.json()); | }).then(res => res.json()); | ||||
} | } | ||||
// POSTs a deep copy of the given store to the refresh-store-vectors endpoint
// (with credentials included) and resolves with the parsed JSON response.
export function refreshStoreVectors(store) {
  const storeCopy = Setting.deepCopy(store);
  const endpoint = `${Setting.ServerUrl}/api/refresh-store-vectors`;
  const request = {
    method: "POST",
    credentials: "include",
    body: JSON.stringify(storeCopy),
  };
  return fetch(endpoint, request).then((response) => response.json());
}