* feat: add refresh store vectors * fix: clean the object.Url from casdoorsdkHEAD
| @@ -0,0 +1,90 @@ | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| package ai | |||||
| import ( | |||||
| "bufio" | |||||
| "context" | |||||
| "fmt" | |||||
| "io" | |||||
| "time" | |||||
| "github.com/sashabaranov/go-openai" | |||||
| ) | |||||
| func splitTxt(f io.ReadCloser) []string { | |||||
| const maxLength = 512 * 3 | |||||
| scanner := bufio.NewScanner(f) | |||||
| var res []string | |||||
| var temp string | |||||
| for scanner.Scan() { | |||||
| line := scanner.Text() | |||||
| if len(temp)+len(line) <= maxLength { | |||||
| temp += line | |||||
| } else { | |||||
| res = append(res, temp) | |||||
| temp = line | |||||
| } | |||||
| } | |||||
| if len(temp) > 0 { | |||||
| res = append(res, temp) | |||||
| } | |||||
| return res | |||||
| } | |||||
| func GetSplitTxt(f io.ReadCloser) []string { | |||||
| return splitTxt(f) | |||||
| } | |||||
| func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) { | |||||
| client := getProxyClientFromToken(authToken) | |||||
| ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) | |||||
| defer cancel() | |||||
| resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||||
| Input: input, | |||||
| Model: openai.AdaEmbeddingV2, | |||||
| }) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return resp.Data[0].Embedding, nil | |||||
| } | |||||
| func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) { | |||||
| var embedding []float32 | |||||
| var err error | |||||
| for i := 0; i < 10; i++ { | |||||
| embedding, err = getEmbedding(authToken, input, i) | |||||
| if err != nil { | |||||
| if i > 0 { | |||||
| fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) | |||||
| } | |||||
| } else { | |||||
| break | |||||
| } | |||||
| } | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } else { | |||||
| return embedding, nil | |||||
| } | |||||
| } | |||||
| @@ -123,3 +123,20 @@ func (c *ApiController) DeleteStore() { | |||||
| c.ResponseOk(sucess) | c.ResponseOk(sucess) | ||||
| } | } | ||||
| func (c *ApiController) RefreshStoreVectors() { | |||||
| var store object.Store | |||||
| err := json.Unmarshal(c.Ctx.Input.RequestBody, &store) | |||||
| if err != nil { | |||||
| c.ResponseError(err.Error()) | |||||
| return | |||||
| } | |||||
| success, err := object.RefreshStoreVectors(&store) | |||||
| if err != nil { | |||||
| c.ResponseError(err.Error()) | |||||
| return | |||||
| } | |||||
| c.ResponseOk(success) | |||||
| } | |||||
| @@ -148,3 +148,20 @@ func DeleteStore(store *Store) (bool, error) { | |||||
| func (store *Store) GetId() string { | func (store *Store) GetId() string { | ||||
| return fmt.Sprintf("%s/%s", store.Owner, store.Name) | return fmt.Sprintf("%s/%s", store.Owner, store.Name) | ||||
| } | } | ||||
| func RefreshStoreVectors(store *Store) (bool, error) { | |||||
| provider, err := getDefaultModelProvider() | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| authToken := provider.ClientSecret | |||||
| success, err := setTxtObjectVector(authToken, store.StorageProvider, "", store.Name) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| if !success { | |||||
| return false, nil | |||||
| } | |||||
| return true, nil | |||||
| } | |||||
| @@ -0,0 +1,137 @@ | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| package object | |||||
| import ( | |||||
| "context" | |||||
| "fmt" | |||||
| "io" | |||||
| "net/http" | |||||
| "strings" | |||||
| "time" | |||||
| "github.com/casbin/casibase/ai" | |||||
| "github.com/casbin/casibase/storage" | |||||
| "github.com/casbin/casibase/util" | |||||
| "golang.org/x/time/rate" | |||||
| ) | |||||
| func isTxt(filename string) bool { | |||||
| return strings.HasSuffix(filename, ".txt") | |||||
| } | |||||
| func filterTxtFiles(files []*storage.Object) []*storage.Object { | |||||
| var res []*storage.Object | |||||
| for _, file := range files { | |||||
| if isTxt(file.Key) { | |||||
| res = append(res, file) | |||||
| } | |||||
| } | |||||
| return res | |||||
| } | |||||
| func getTxtFiles(provider string, prefix string) ([]*storage.Object, error) { | |||||
| files, err := storage.ListObjects(provider, prefix) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return filterTxtFiles(files), nil | |||||
| } | |||||
| func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { | |||||
| resp, err := http.Get(object.Url) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if resp.StatusCode != http.StatusOK { | |||||
| resp.Body.Close() | |||||
| return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode) | |||||
| } | |||||
| return resp.Body, nil | |||||
| } | |||||
| func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | |||||
| embedding, err := ai.GetEmbeddingSafe(authToken, []string{text}) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| displayName := text | |||||
| if len(text) > 25 { | |||||
| displayName = text[:25] | |||||
| } | |||||
| vector := &Vector{ | |||||
| Owner: "admin", | |||||
| Name: fmt.Sprintf("vector_%s", util.GetRandomName()), | |||||
| CreatedTime: util.GetCurrentTime(), | |||||
| DisplayName: displayName, | |||||
| Store: storeName, | |||||
| File: fileName, | |||||
| Text: text, | |||||
| Data: embedding, | |||||
| } | |||||
| return AddVector(vector) | |||||
| } | |||||
| func setTxtObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { | |||||
| lb := rate.NewLimiter(rate.Every(time.Minute), 3) | |||||
| txtObjects, err := getTxtFiles(provider, key) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| if len(txtObjects) == 0 { | |||||
| return false, nil | |||||
| } | |||||
| for _, txtObject := range txtObjects { | |||||
| readCloser, err := getObjectReadCloser(txtObject) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| defer readCloser.Close() | |||||
| splitTxts := ai.GetSplitTxt(readCloser) | |||||
| for _, splitTxt := range splitTxts { | |||||
| if lb.Allow() { | |||||
| success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| if !success { | |||||
| return false, nil | |||||
| } | |||||
| } else { | |||||
| err := lb.Wait(context.Background()) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| if !success { | |||||
| return false, nil | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return true, nil | |||||
| } | |||||
| @@ -65,6 +65,7 @@ func initAPI() { | |||||
| beego.Router("/api/update-store", &controllers.ApiController{}, "POST:UpdateStore") | beego.Router("/api/update-store", &controllers.ApiController{}, "POST:UpdateStore") | ||||
| beego.Router("/api/add-store", &controllers.ApiController{}, "POST:AddStore") | beego.Router("/api/add-store", &controllers.ApiController{}, "POST:AddStore") | ||||
| beego.Router("/api/delete-store", &controllers.ApiController{}, "POST:DeleteStore") | beego.Router("/api/delete-store", &controllers.ApiController{}, "POST:DeleteStore") | ||||
| beego.Router("/api/refresh-store-vectors", &controllers.ApiController{}, "POST:RefreshStoreVectors") | |||||
| beego.Router("/api/get-storage-providers", &controllers.ApiController{}, "GET:GetStorageProviders") | beego.Router("/api/get-storage-providers", &controllers.ApiController{}, "GET:GetStorageProviders") | ||||
| @@ -26,6 +26,7 @@ class StoreListPage extends React.Component { | |||||
| this.state = { | this.state = { | ||||
| classes: props, | classes: props, | ||||
| stores: null, | stores: null, | ||||
| generating: false, | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -93,6 +94,23 @@ class StoreListPage extends React.Component { | |||||
| }); | }); | ||||
| } | } | ||||
| refreshStoreVectors(i) { | |||||
| this.setState({generating: true}); | |||||
| StoreBackend.refreshStoreVectors(this.state.stores[i]) | |||||
| .then((res) => { | |||||
| if (res.status === "ok") { | |||||
| Setting.showMessage("success", "Vectors generated successfully"); | |||||
| } else { | |||||
| Setting.showMessage("error", `Vectors failed to generate: ${res.msg}`); | |||||
| } | |||||
| this.setState({generating: false}); | |||||
| }) | |||||
| .catch(error => { | |||||
| Setting.showMessage("error", `Vectors failed to generate: ${error}`); | |||||
| this.setState({generating: false}); | |||||
| }); | |||||
| } | |||||
| renderTable(stores) { | renderTable(stores) { | ||||
| const columns = [ | const columns = [ | ||||
| { | { | ||||
| @@ -135,6 +153,7 @@ class StoreListPage extends React.Component { | |||||
| { | { | ||||
| !Setting.isLocalAdminUser(this.props.account) ? null : ( | !Setting.isLocalAdminUser(this.props.account) ? null : ( | ||||
| <React.Fragment> | <React.Fragment> | ||||
| <Button style={{marginBottom: "10px", marginRight: "10px"}} disabled={this.state.generating} onClick={() => this.refreshStoreVectors(index)}>{i18next.t("store:Refresh Vectors")}</Button> | |||||
| <Button style={{marginBottom: "10px", marginRight: "10px"}} type="primary" onClick={() => this.props.history.push(`/stores/${record.owner}/${record.name}`)}>{i18next.t("general:Edit")}</Button> | <Button style={{marginBottom: "10px", marginRight: "10px"}} type="primary" onClick={() => this.props.history.push(`/stores/${record.owner}/${record.name}`)}>{i18next.t("general:Edit")}</Button> | ||||
| <Popconfirm | <Popconfirm | ||||
| title={`Sure to delete store: ${record.name} ?`} | title={`Sure to delete store: ${record.name} ?`} | ||||
| @@ -145,6 +145,13 @@ class VectorListPage extends React.Component { | |||||
| key: "text", | key: "text", | ||||
| width: "200px", | width: "200px", | ||||
| sorter: (a, b) => a.text.localeCompare(b.text), | sorter: (a, b) => a.text.localeCompare(b.text), | ||||
| render: (text, record, index) => { | |||||
| return ( | |||||
| <div style={{maxWidth: "200px"}}> | |||||
| {Setting.getShortText(text)} | |||||
| </div> | |||||
| ); | |||||
| }, | |||||
| }, | }, | ||||
| { | { | ||||
| title: i18next.t("vector:Data"), | title: i18next.t("vector:Data"), | ||||
| @@ -153,7 +160,11 @@ class VectorListPage extends React.Component { | |||||
| width: "200px", | width: "200px", | ||||
| sorter: (a, b) => a.data.localeCompare(b.data), | sorter: (a, b) => a.data.localeCompare(b.data), | ||||
| render: (text, record, index) => { | render: (text, record, index) => { | ||||
| return JSON.stringify(text); | |||||
| return ( | |||||
| <div style={{maxWidth: "200px"}}> | |||||
| {Setting.getShortText(JSON.stringify(text))} | |||||
| </div> | |||||
| ); | |||||
| }, | }, | ||||
| }, | }, | ||||
| { | { | ||||
| @@ -1,17 +1,17 @@ | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| import * as Setting from "../Setting"; | import * as Setting from "../Setting"; | ||||
| export function getGlobalStores() { | export function getGlobalStores() { | ||||
| @@ -61,3 +61,12 @@ export function deleteStore(store) { | |||||
| body: JSON.stringify(newStore), | body: JSON.stringify(newStore), | ||||
| }).then(res => res.json()); | }).then(res => res.json()); | ||||
| } | } | ||||
| export function refreshStoreVectors(store) { | |||||
| const newStore = Setting.deepCopy(store); | |||||
| return fetch(`${Setting.ServerUrl}/api/refresh-store-vectors`, { | |||||
| method: "POST", | |||||
| credentials: "include", | |||||
| body: JSON.stringify(newStore), | |||||
| }).then(res => res.json()); | |||||
| } | |||||