diff --git a/object/file.go b/object/file.go index f82ea13..5d524fc 100644 --- a/object/file.go +++ b/object/file.go @@ -20,8 +20,6 @@ import ( "io" "mime/multipart" "strings" - - "github.com/casbin/casibase/storage" ) func UpdateFile(storeId string, key string, file *File) bool { @@ -37,6 +35,11 @@ func AddFile(storeId string, userName string, key string, isLeaf bool, filename return false, nil, nil } + storageProviderObj, err := store.GetStorageProviderObj() + if err != nil { + return false, nil, err + } + var objectKey string var fileBuffer *bytes.Buffer if isLeaf { @@ -49,7 +52,7 @@ func AddFile(storeId string, userName string, key string, isLeaf bool, filename } bs := fileBuffer.Bytes() - err = storage.PutObject(store.StorageProvider, userName, store.Name, objectKey, fileBuffer) + err = storageProviderObj.PutObject(userName, store.Name, objectKey, fileBuffer) if err != nil { return false, nil, err } @@ -60,7 +63,7 @@ func AddFile(storeId string, userName string, key string, isLeaf bool, filename objectKey = strings.TrimLeft(objectKey, "/") fileBuffer = bytes.NewBuffer(nil) bs := fileBuffer.Bytes() - err = storage.PutObject(store.StorageProvider, userName, store.Name, objectKey, fileBuffer) + err = storageProviderObj.PutObject(userName, store.Name, objectKey, fileBuffer) if err != nil { return false, nil, err } @@ -78,19 +81,24 @@ func DeleteFile(storeId string, key string, isLeaf bool) (bool, error) { return false, nil } + storageProviderObj, err := store.GetStorageProviderObj() + if err != nil { + return false, err + } + if isLeaf { - err = storage.DeleteObject(store.StorageProvider, key) + err = storageProviderObj.DeleteObject(key) if err != nil { return false, err } } else { - objects, err := storage.ListObjects(store.StorageProvider, key) + objects, err := storageProviderObj.ListObjects(key) if err != nil { return false, err } for _, object := range objects { - err = storage.DeleteObject(store.StorageProvider, object.Key) + err = storageProviderObj.DeleteObject(object.Key) if err != nil { return false, err } diff --git a/object/init.go b/object/init.go index e6938bb..93033c2 100644 --- a/object/init.go +++ b/object/init.go @@ -24,7 +24,7 @@ func InitDb() { } func initBuiltInStore() bool { - store, err := getStore("admin", "built-in") + store, err := getStore("admin", "store-built-in") if err != nil { panic(err) } @@ -35,10 +35,10 @@ func initBuiltInStore() bool { store = &Store{ Owner: "admin", - Name: "store-default", + Name: "store-built-in", CreatedTime: util.GetCurrentTime(), - DisplayName: "Data Store - Default", - StorageProvider: "", + DisplayName: "Built-in Store", + StorageProvider: "provider-storage-built-in", ModelProvider: "", EmbeddingProvider: "", } @@ -51,7 +51,7 @@ func initBuiltInStore() bool { } func initBuiltInProvider() { - provider, err := GetProvider(util.GetId("admin", "provider_captcha_default")) + provider, err := GetProvider(util.GetId("admin", "provider-storage-local-built-in")) if err != nil { panic(err) } @@ -62,11 +62,12 @@ func initBuiltInProvider() { provider = &Provider{ Owner: "admin", - Name: "provider_captcha_default", + Name: "provider-storage-built-in", CreatedTime: util.GetCurrentTime(), - DisplayName: "Captcha Default", - Category: "Captcha", - Type: "Default", + DisplayName: "Built-in Storage Provider", + Category: "Storage", + Type: "Local File System", + ClientId: "F:/github_repos/casdoor-website", } _, err = AddProvider(provider) if err != nil { diff --git a/object/provider.go b/object/provider.go index 61ab19b..e32dcfc 100644 --- a/object/provider.go +++ b/object/provider.go @@ -19,6 +19,7 @@ import ( "github.com/casbin/casibase/embedding" "github.com/casbin/casibase/model" + "github.com/casbin/casibase/storage" "github.com/casbin/casibase/util" "xorm.io/core" ) @@ -103,6 +104,20 @@ func GetProvider(id string) (*Provider, error) { return getProvider(owner, name) } +func GetDefaultStorageProvider() (*Provider, error) { + provider := Provider{Owner: "admin", Category: "Storage"} + existed, err := adapter.engine.Get(&provider) + if err != nil { + return &provider, err + } + + if !existed { + return nil, nil + } + + return &provider, nil +} + func GetDefaultModelProvider() (*Provider, error) { provider := Provider{Owner: "admin", Category: "Model"} existed, err := adapter.engine.Get(&provider) @@ -176,6 +191,19 @@ func (provider *Provider) GetId() string { return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) } +func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { + pProvider, err := storage.GetStorageProvider(p.Type, p.ClientId, p.Name) + if err != nil { + return nil, err + } + + if pProvider == nil { + return nil, fmt.Errorf("the storage provider type: %s is not supported", p.Type) + } + + return pProvider, nil +} + func (p *Provider) GetModelProvider() (model.ModelProvider, error) { pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) if err != nil { diff --git a/object/store.go b/object/store.go index caf5d53..0530361 100644 --- a/object/store.go +++ b/object/store.go @@ -17,6 +17,7 @@ package object import ( "fmt" + "github.com/casbin/casibase/storage" "github.com/casbin/casibase/util" "xorm.io/core" ) @@ -151,6 +152,26 @@ func (store *Store) GetId() string { return fmt.Sprintf("%s/%s", store.Owner, store.Name) } +func (store *Store) GetStorageProviderObj() (storage.StorageProvider, error) { + var provider *Provider + var err error + if store.StorageProvider == "" { + provider, err = GetDefaultStorageProvider() + } else { + providerId := util.GetIdFromOwnerAndName(store.Owner, store.StorageProvider) + provider, err = GetProvider(providerId) + } + if err != nil { + return nil, err + } + + if provider != nil { + return provider.GetStorageProviderObj() + } else { + return storage.NewCasdoorProvider(store.StorageProvider) + } +} + func (store *Store) GetEmbeddingProvider() (*Provider, error) { if store.EmbeddingProvider == "" { return GetDefaultEmbeddingProvider() @@ -161,6 +182,11 @@ func (store *Store) GetEmbeddingProvider() (*Provider, error) { } func RefreshStoreVectors(store *Store) (bool, error) { + storageProviderObj, err := store.GetStorageProviderObj() + if err != nil { + return false, err + } + embeddingProvider, err := store.GetEmbeddingProvider() if err != nil { return false, err @@ -171,6 +197,6 @@ func RefreshStoreVectors(store *Store) (bool, error) { return false, err } - ok, err := addVectorsForStore(embeddingProviderObj, store.StorageProvider, "", store.Name) + ok, err := addVectorsForStore(storageProviderObj, embeddingProviderObj, "", store.Name) return ok, err } diff --git a/object/store_provider.go b/object/store_provider.go index d0130f2..468456f 100644 --- a/object/store_provider.go +++ b/object/store_provider.go @@ -80,7 +80,12 @@ func isObjectLeaf(object *storage.Object) bool { } func (store *Store) Populate() error { - objects, err := storage.ListObjects(store.StorageProvider, "") + storageProviderObj, err := store.GetStorageProviderObj() + if err != nil { + return err + } + + objects, err := storageProviderObj.ListObjects("") if err != nil { return err } @@ -125,7 +130,12 @@ func (store *Store) Populate() error { } func (store *Store) GetVideoData() ([]string, error) { - objects, err := storage.ListObjects(store.StorageProvider, "2023/视频附件") + storageProviderObj, err := store.GetStorageProviderObj() + if err != nil { + return nil, err + } + + objects, err := storageProviderObj.ListObjects("2023/视频附件") if err != nil { return nil, err } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index 17dc540..0192cc7 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -44,15 +44,6 @@ func filterTextFiles(files []*storage.Object) []*storage.Object { return res } -func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, error) { - files, err := storage.ListObjects(provider, prefix) - if err != nil { - return nil, err - } - - return filterTextFiles(files), nil -} - func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) { data, err := queryVectorSafe(embeddingProviderObj, text) if err != nil { @@ -77,20 +68,21 @@ func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text st return AddVector(vector) } -func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storageProviderName string, key string, storeName string) (bool, error) { +func addVectorsForStore(storageProviderObj storage.StorageProvider, embeddingProviderObj embedding.EmbeddingProvider, prefix string, storeName string) (bool, error) { var affected bool - var err error - objs, err := getFilteredFileObjects(storageProviderName, key) + files, err := storageProviderObj.ListObjects(prefix) if err != nil { return false, err } + files = filterTextFiles(files) + timeLimiter := rate.NewLimiter(rate.Every(time.Minute), 3) - for _, obj := range objs { + for _, file := range files { var text string - fileExt := filepath.Ext(obj.Key) - text, err = txt.GetParsedTextFromUrl(obj.Url, fileExt) + fileExt := filepath.Ext(file.Key) + text, err = txt.GetParsedTextFromUrl(file.Url, fileExt) if err != nil { return false, err } @@ -99,7 +91,7 @@ func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storag for i, textSection := range textSections { if timeLimiter.Allow() { fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key) } else { err = timeLimiter.Wait(context.Background()) if err != nil { @@ -107,7 +99,7 @@ func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storag } fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) - affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) + affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, file.Key) } } } diff --git a/storage/storage.go b/storage/casdoor.go similarity index 67% rename from storage/storage.go rename to storage/casdoor.go index cd99e22..77de8da 100644 --- a/storage/storage.go +++ b/storage/casdoor.go @@ -22,21 +22,22 @@ import ( "github.com/casdoor/casdoor-go-sdk/casdoorsdk" ) -type Object struct { - Key string - LastModified string - Size int64 - Url string +type CasdoorProvider struct { + providerName string } -func ListObjects(provider string, prefix string) ([]*Object, error) { - if provider == "" { - return nil, fmt.Errorf("storage provider is empty") +func NewCasdoorProvider(providerName string) (*CasdoorProvider, error) { + if providerName == "" { + return nil, fmt.Errorf("storage provider name: [%s] doesn't exist", providerName) } + return &CasdoorProvider{providerName: providerName}, nil +} + +func (p *CasdoorProvider) ListObjects(prefix string) ([]*Object, error) { casdoorOrganization := beego.AppConfig.String("casdoorOrganization") casdoorApplication := beego.AppConfig.String("casdoorApplication") - resources, err := casdoorsdk.GetResources(casdoorOrganization, casdoorApplication, "provider", provider, "Direct", prefix) + resources, err := casdoorsdk.GetResources(casdoorOrganization, casdoorApplication, "provider", p.providerName, "Direct", prefix) if err != nil { return nil, err } @@ -53,24 +54,16 @@ func ListObjects(provider string, prefix string) ([]*Object, error) { return res, nil } -func PutObject(provider string, user string, parent string, key string, fileBuffer *bytes.Buffer) error { - if provider == "" { - return fmt.Errorf("storage provider is empty") - } - - _, _, err := casdoorsdk.UploadResource(user, "Casibase", parent, fmt.Sprintf("Direct/%s/%s", provider, key), fileBuffer.Bytes()) +func (p *CasdoorProvider) PutObject(user string, parent string, key string, fileBuffer *bytes.Buffer) error { + _, _, err := casdoorsdk.UploadResource(user, "Casibase", parent, fmt.Sprintf("Direct/%s/%s", &p.providerName, key), fileBuffer.Bytes()) if err != nil { return err } return nil } -func DeleteObject(provider string, key string) error { - if provider == "" { - return fmt.Errorf("storage provider is empty") - } - - _, err := casdoorsdk.DeleteResource(fmt.Sprintf("Direct/%s/%s", provider, key)) +func (p *CasdoorProvider) DeleteObject(key string) error { + _, err := casdoorsdk.DeleteResource(fmt.Sprintf("Direct/%s/%s", p.providerName, key)) if err != nil { return err } diff --git a/storage/local_file_system.go b/storage/local_file_system.go new file mode 100644 index 0000000..f64799e --- /dev/null +++ b/storage/local_file_system.go @@ -0,0 +1,85 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" +) + +type LocalFileSystemStorageProvider struct { + path string +} + +func NewLocalFileSystemStorageProvider(path string) (*LocalFileSystemStorageProvider, error) { + path = strings.ReplaceAll(path, "\\", "/") + return &LocalFileSystemStorageProvider{path: path}, nil +} + +func (p *LocalFileSystemStorageProvider) ListObjects(prefix string) ([]*Object, error) { + objects := []*Object{} + fullPath := p.path + + filepath.Walk(fullPath, func(path string, info os.FileInfo, err error) error { + if path == fullPath { + return nil + } + + base := filepath.Base(path) + if info.IsDir() && (strings.HasPrefix(base, ".") || base == "node_modules") { + return filepath.SkipDir + } + + if err == nil && !info.IsDir() { + modTime := info.ModTime() + path = strings.ReplaceAll(path, "\\", "/") + relativePath := strings.TrimPrefix(path, fullPath) + relativePath = strings.TrimPrefix(relativePath, "/") + objects = append(objects, &Object{ + Key: relativePath, + LastModified: modTime.Format(time.RFC3339), + Size: info.Size(), + Url: "", + }) + } + return nil + }) + + return objects, nil +} + +func (p *LocalFileSystemStorageProvider) PutObject(user string, parent string, key string, fileBuffer *bytes.Buffer) error { + fullPath := p.path + + err := os.MkdirAll(filepath.Dir(fullPath), os.ModePerm) + if err != nil { + return fmt.Errorf("Casdoor fails to create folder: \"%s\" for local file system storage provider: %s. Make sure Casdoor process has correct permission to create/access it, or you can create it manually in advance", filepath.Dir(fullPath), err.Error()) + } + + dst, err := os.Create(filepath.Clean(fullPath)) + if err == nil { + _, err = io.Copy(dst, fileBuffer) + } + return err +} + +func (p *LocalFileSystemStorageProvider) DeleteObject(key string) error { + return os.Remove(filepath.Join(p.path, key)) +} diff --git a/storage/provider.go b/storage/provider.go new file mode 100644 index 0000000..0c2bfd4 --- /dev/null +++ b/storage/provider.go @@ -0,0 +1,45 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import "bytes" + +type Object struct { + Key string + LastModified string + Size int64 + Url string +} + +type StorageProvider interface { + ListObjects(prefix string) ([]*Object, error) + PutObject(user string, parent string, key string, fileBuffer *bytes.Buffer) error + DeleteObject(key string) error +} + +func GetStorageProvider(typ string, clientId string, providerName string) (StorageProvider, error) { + var p StorageProvider + var err error + if typ == "Local File System" { + p, err = NewLocalFileSystemStorageProvider(clientId) + } else { + p, err = NewCasdoorProvider(providerName) + } + + if err != nil { + return nil, err + } + return p, nil +} diff --git a/storage/storage_test.go b/storage/provider_test.go similarity index 91% rename from storage/storage_test.go rename to storage/provider_test.go index b30b7dc..53a6da5 100644 --- a/storage/storage_test.go +++ b/storage/provider_test.go @@ -31,7 +31,8 @@ func TestStorage(t *testing.T) { controllers.InitAuthConfig() provider := "provider_storage_casibase" - objects, err := storage.ListObjects(provider, "") + providerObj, err := storage.NewCasdoorProvider(provider) + objects, err := providerObj.ListObjects("") if err != nil { panic(err) } diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index de7a49e..fe9deaa 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -101,6 +101,7 @@ class ProviderEditPage extends React.Component { - - - {i18next.t("provider:Sub type")}: - - - - - { - this.state.provider.type !== "Ernie" ? null : ( + this.state.provider.category === "Storage" ? null : ( - {i18next.t("provider:API key")}: + {i18next.t("provider:Sub type")}: + + + + + + ) + } + { + (this.state.provider.type !== "Ernie" && this.state.provider.category !== "Storage") ? null : ( + + + { + (this.state.provider.category !== "Storage") ? i18next.t("provider:API key") : + i18next.t("provider:Path")}: { @@ -150,16 +157,20 @@ class ProviderEditPage extends React.Component { ) } - - - {i18next.t("provider:Secret key")}: - - - { - this.updateProviderField("clientSecret", e.target.value); - }} /> - - + { + this.state.provider.category === "Storage" ? null : ( + + + {i18next.t("provider:Secret key")}: + + + { + this.updateProviderField("clientSecret", e.target.value); + }} /> + + + ) + } {i18next.t("general:Provider URL")}: diff --git a/web/src/ProviderListPage.js b/web/src/ProviderListPage.js index 4614943..dba9b75 100644 --- a/web/src/ProviderListPage.js +++ b/web/src/ProviderListPage.js @@ -103,7 +103,7 @@ class ProviderListPage extends React.Component { title: i18next.t("general:Name"), dataIndex: "name", key: "name", - width: "160px", + width: "180px", sorter: (a, b) => a.name.localeCompare(b.name), render: (text, record, index) => { return ( @@ -131,7 +131,7 @@ class ProviderListPage extends React.Component { title: i18next.t("provider:Type"), dataIndex: "type", key: "type", - width: "120px", + width: "150px", sorter: (a, b) => a.type.localeCompare(b.type), }, { @@ -145,7 +145,7 @@ class ProviderListPage extends React.Component { title: i18next.t("provider:API key"), dataIndex: "clientId", key: "clientId", - width: "160px", + width: "240px", sorter: (a, b) => a.clientId.localeCompare(b.clientId), }, { diff --git a/web/src/Setting.js b/web/src/Setting.js index ba006a9..0bf1815 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -625,7 +625,13 @@ export function isResponseDenied(data) { } export function getProviderTypeOptions(category) { - if (category === "Model") { + if (category === "Storage") { + return ( + [ + {id: "Local File System", name: "Local File System"}, + ] + ); + } else if (category === "Model") { return ( [ {id: "OpenAI", name: "OpenAI"},