* feat: add refresh store vectors * fix: clean the object.Url from casdoorsdkHEAD
| @@ -0,0 +1,90 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package ai | |||
| import ( | |||
| "bufio" | |||
| "context" | |||
| "fmt" | |||
| "io" | |||
| "time" | |||
| "github.com/sashabaranov/go-openai" | |||
| ) | |||
// splitTxt reads f line by line and packs consecutive lines into chunks whose
// byte length does not exceed maxLength, returning the chunks in file order.
//
// Lines are concatenated directly: bufio.Scanner strips the line terminator
// and no separator is re-inserted between lines. A single line longer than
// maxLength is emitted as its own (oversized) chunk. Empty chunks are never
// emitted. The function does not close f, and scanner.Err() is not surfaced:
// if scanning stops early, the text gathered so far is returned.
func splitTxt(f io.ReadCloser) []string {
	const maxLength = 512 * 3

	scanner := bufio.NewScanner(f)
	var res []string
	var temp string

	for scanner.Scan() {
		line := scanner.Text()
		if len(temp)+len(line) <= maxLength {
			temp += line
			continue
		}

		// The accumulator may still be empty here (the incoming line alone is
		// longer than maxLength); flushing it would produce an empty chunk.
		if len(temp) > 0 {
			res = append(res, temp)
		}
		temp = line
	}

	if len(temp) > 0 {
		res = append(res, temp)
	}
	return res
}
| func GetSplitTxt(f io.ReadCloser) []string { | |||
| return splitTxt(f) | |||
| } | |||
| func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) { | |||
| client := getProxyClientFromToken(authToken) | |||
| ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) | |||
| defer cancel() | |||
| resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||
| Input: input, | |||
| Model: openai.AdaEmbeddingV2, | |||
| }) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return resp.Data[0].Embedding, nil | |||
| } | |||
| func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) { | |||
| var embedding []float32 | |||
| var err error | |||
| for i := 0; i < 10; i++ { | |||
| embedding, err = getEmbedding(authToken, input, i) | |||
| if err != nil { | |||
| if i > 0 { | |||
| fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) | |||
| } | |||
| } else { | |||
| break | |||
| } | |||
| } | |||
| if err != nil { | |||
| return nil, err | |||
| } else { | |||
| return embedding, nil | |||
| } | |||
| } | |||
| @@ -123,3 +123,20 @@ func (c *ApiController) DeleteStore() { | |||
| c.ResponseOk(sucess) | |||
| } | |||
| func (c *ApiController) RefreshStoreVectors() { | |||
| var store object.Store | |||
| err := json.Unmarshal(c.Ctx.Input.RequestBody, &store) | |||
| if err != nil { | |||
| c.ResponseError(err.Error()) | |||
| return | |||
| } | |||
| success, err := object.RefreshStoreVectors(&store) | |||
| if err != nil { | |||
| c.ResponseError(err.Error()) | |||
| return | |||
| } | |||
| c.ResponseOk(success) | |||
| } | |||
| @@ -148,3 +148,20 @@ func DeleteStore(store *Store) (bool, error) { | |||
| func (store *Store) GetId() string { | |||
| return fmt.Sprintf("%s/%s", store.Owner, store.Name) | |||
| } | |||
| func RefreshStoreVectors(store *Store) (bool, error) { | |||
| provider, err := getDefaultModelProvider() | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| authToken := provider.ClientSecret | |||
| success, err := setTxtObjectVector(authToken, store.StorageProvider, "", store.Name) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if !success { | |||
| return false, nil | |||
| } | |||
| return true, nil | |||
| } | |||
| @@ -0,0 +1,137 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package object | |||
| import ( | |||
| "context" | |||
| "fmt" | |||
| "io" | |||
| "net/http" | |||
| "strings" | |||
| "time" | |||
| "github.com/casbin/casibase/ai" | |||
| "github.com/casbin/casibase/storage" | |||
| "github.com/casbin/casibase/util" | |||
| "golang.org/x/time/rate" | |||
| ) | |||
// isTxt reports whether filename ends with the ".txt" suffix. The comparison
// is case-sensitive.
func isTxt(filename string) bool {
	const txtSuffix = ".txt"
	matches := strings.HasSuffix(filename, txtSuffix)
	return matches
}
| func filterTxtFiles(files []*storage.Object) []*storage.Object { | |||
| var res []*storage.Object | |||
| for _, file := range files { | |||
| if isTxt(file.Key) { | |||
| res = append(res, file) | |||
| } | |||
| } | |||
| return res | |||
| } | |||
| func getTxtFiles(provider string, prefix string) ([]*storage.Object, error) { | |||
| files, err := storage.ListObjects(provider, prefix) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return filterTxtFiles(files), nil | |||
| } | |||
| func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { | |||
| resp, err := http.Get(object.Url) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if resp.StatusCode != http.StatusOK { | |||
| resp.Body.Close() | |||
| return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode) | |||
| } | |||
| return resp.Body, nil | |||
| } | |||
| func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | |||
| embedding, err := ai.GetEmbeddingSafe(authToken, []string{text}) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| displayName := text | |||
| if len(text) > 25 { | |||
| displayName = text[:25] | |||
| } | |||
| vector := &Vector{ | |||
| Owner: "admin", | |||
| Name: fmt.Sprintf("vector_%s", util.GetRandomName()), | |||
| CreatedTime: util.GetCurrentTime(), | |||
| DisplayName: displayName, | |||
| Store: storeName, | |||
| File: fileName, | |||
| Text: text, | |||
| Data: embedding, | |||
| } | |||
| return AddVector(vector) | |||
| } | |||
| func setTxtObjectVector(authToken string, provider string, key string, storeName string) (bool, error) { | |||
| lb := rate.NewLimiter(rate.Every(time.Minute), 3) | |||
| txtObjects, err := getTxtFiles(provider, key) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if len(txtObjects) == 0 { | |||
| return false, nil | |||
| } | |||
| for _, txtObject := range txtObjects { | |||
| readCloser, err := getObjectReadCloser(txtObject) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| defer readCloser.Close() | |||
| splitTxts := ai.GetSplitTxt(readCloser) | |||
| for _, splitTxt := range splitTxts { | |||
| if lb.Allow() { | |||
| success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if !success { | |||
| return false, nil | |||
| } | |||
| } else { | |||
| err := lb.Wait(context.Background()) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| success, err := addEmbeddedVector(authToken, splitTxt, storeName, txtObject.Key) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if !success { | |||
| return false, nil | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return true, nil | |||
| } | |||
| @@ -65,6 +65,7 @@ func initAPI() { | |||
| beego.Router("/api/update-store", &controllers.ApiController{}, "POST:UpdateStore") | |||
| beego.Router("/api/add-store", &controllers.ApiController{}, "POST:AddStore") | |||
| beego.Router("/api/delete-store", &controllers.ApiController{}, "POST:DeleteStore") | |||
| beego.Router("/api/refresh-store-vectors", &controllers.ApiController{}, "POST:RefreshStoreVectors") | |||
| beego.Router("/api/get-storage-providers", &controllers.ApiController{}, "GET:GetStorageProviders") | |||
| @@ -26,6 +26,7 @@ class StoreListPage extends React.Component { | |||
| this.state = { | |||
| classes: props, | |||
| stores: null, | |||
| generating: false, | |||
| }; | |||
| } | |||
| @@ -93,6 +94,23 @@ class StoreListPage extends React.Component { | |||
| }); | |||
| } | |||
| refreshStoreVectors(i) { | |||
| this.setState({generating: true}); | |||
| StoreBackend.refreshStoreVectors(this.state.stores[i]) | |||
| .then((res) => { | |||
| if (res.status === "ok") { | |||
| Setting.showMessage("success", "Vectors generated successfully"); | |||
| } else { | |||
| Setting.showMessage("error", `Vectors failed to generate: ${res.msg}`); | |||
| } | |||
| this.setState({generating: false}); | |||
| }) | |||
| .catch(error => { | |||
| Setting.showMessage("error", `Vectors failed to generate: ${error}`); | |||
| this.setState({generating: false}); | |||
| }); | |||
| } | |||
| renderTable(stores) { | |||
| const columns = [ | |||
| { | |||
| @@ -135,6 +153,7 @@ class StoreListPage extends React.Component { | |||
| { | |||
| !Setting.isLocalAdminUser(this.props.account) ? null : ( | |||
| <React.Fragment> | |||
| <Button style={{marginBottom: "10px", marginRight: "10px"}} disabled={this.state.generating} onClick={() => this.refreshStoreVectors(index)}>{i18next.t("store:Refresh Vectors")}</Button> | |||
| <Button style={{marginBottom: "10px", marginRight: "10px"}} type="primary" onClick={() => this.props.history.push(`/stores/${record.owner}/${record.name}`)}>{i18next.t("general:Edit")}</Button> | |||
| <Popconfirm | |||
| title={`Sure to delete store: ${record.name} ?`} | |||
| @@ -145,6 +145,13 @@ class VectorListPage extends React.Component { | |||
| key: "text", | |||
| width: "200px", | |||
| sorter: (a, b) => a.text.localeCompare(b.text), | |||
| render: (text, record, index) => { | |||
| return ( | |||
| <div style={{maxWidth: "200px"}}> | |||
| {Setting.getShortText(text)} | |||
| </div> | |||
| ); | |||
| }, | |||
| }, | |||
| { | |||
| title: i18next.t("vector:Data"), | |||
| @@ -153,7 +160,11 @@ class VectorListPage extends React.Component { | |||
| width: "200px", | |||
| sorter: (a, b) => a.data.localeCompare(b.data), | |||
| render: (text, record, index) => { | |||
| return JSON.stringify(text); | |||
| return ( | |||
| <div style={{maxWidth: "200px"}}> | |||
| {Setting.getShortText(JSON.stringify(text))} | |||
| </div> | |||
| ); | |||
| }, | |||
| }, | |||
| { | |||
| @@ -1,17 +1,17 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| import * as Setting from "../Setting"; | |||
| export function getGlobalStores() { | |||
| @@ -61,3 +61,12 @@ export function deleteStore(store) { | |||
| body: JSON.stringify(newStore), | |||
| }).then(res => res.json()); | |||
| } | |||
// POST a deep copy of `store` as JSON to /api/refresh-store-vectors (with
// credentials included) and resolve with the parsed JSON response.
export function refreshStoreVectors(store) {
  const storeCopy = Setting.deepCopy(store);
  const endpoint = `${Setting.ServerUrl}/api/refresh-store-vectors`;
  const options = {
    method: "POST",
    credentials: "include",
    body: JSON.stringify(storeCopy),
  };
  return fetch(endpoint, options).then(res => res.json());
}