You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

provider.go 5.0 kB


  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import (
  16. "fmt"
  17. "github.com/casbin/casibase/embedding"
  18. "github.com/casbin/casibase/model"
  19. "github.com/casbin/casibase/util"
  20. "xorm.io/core"
  21. )
  // Provider is a persisted provider record. Owner and Name together form the
  // primary key (both carry the xorm "pk" tag); GetId renders them as
  // "owner/name".
  type Provider struct {
  	Owner       string `xorm:"varchar(100) notnull pk" json:"owner"`
  	Name        string `xorm:"varchar(100) notnull pk" json:"name"`
  	CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
  	DisplayName string `xorm:"varchar(100)" json:"displayName"`

  	// Category is matched against "Model" / "Embedding" by the
  	// GetDefault*Provider lookups.
  	Category string `xorm:"varchar(100)" json:"category"`
  	// Type and SubType are forwarded to the model/embedding factories.
  	Type    string `xorm:"varchar(100)" json:"type"`
  	SubType string `xorm:"varchar(100)" json:"subType"`

  	// Credentials. ClientSecret is replaced by "***" when masking is
  	// enabled (see GetMaskedProvider).
  	ClientId     string `xorm:"varchar(100)" json:"clientId"`
  	ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"`

  	ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
  }
  34. func GetMaskedProvider(provider *Provider, isMaskEnabled bool) *Provider {
  35. if !isMaskEnabled {
  36. return provider
  37. }
  38. if provider == nil {
  39. return nil
  40. }
  41. if provider.ClientSecret != "" {
  42. provider.ClientSecret = "***"
  43. }
  44. return provider
  45. }
  46. func GetMaskedProviders(providers []*Provider, isMaskEnabled bool) []*Provider {
  47. if !isMaskEnabled {
  48. return providers
  49. }
  50. for _, provider := range providers {
  51. provider = GetMaskedProvider(provider, isMaskEnabled)
  52. }
  53. return providers
  54. }
  55. func GetGlobalProviders() ([]*Provider, error) {
  56. providers := []*Provider{}
  57. err := adapter.engine.Asc("owner").Desc("created_time").Find(&providers)
  58. if err != nil {
  59. return providers, err
  60. }
  61. return providers, nil
  62. }
  63. func GetProviders(owner string) ([]*Provider, error) {
  64. providers := []*Provider{}
  65. err := adapter.engine.Desc("created_time").Find(&providers, &Provider{Owner: owner})
  66. if err != nil {
  67. return providers, err
  68. }
  69. return providers, nil
  70. }
  71. func getProvider(owner string, name string) (*Provider, error) {
  72. provider := Provider{Owner: owner, Name: name}
  73. existed, err := adapter.engine.Get(&provider)
  74. if err != nil {
  75. return &provider, err
  76. }
  77. if existed {
  78. return &provider, nil
  79. } else {
  80. return nil, nil
  81. }
  82. }
  83. func GetProvider(id string) (*Provider, error) {
  84. owner, name := util.GetOwnerAndNameFromId(id)
  85. return getProvider(owner, name)
  86. }
  87. func GetDefaultModelProvider() (*Provider, error) {
  88. provider := Provider{Owner: "admin", Category: "Model"}
  89. existed, err := adapter.engine.Get(&provider)
  90. if err != nil {
  91. return &provider, err
  92. }
  93. if !existed {
  94. return nil, nil
  95. }
  96. return &provider, nil
  97. }
  98. func GetDefaultEmbeddingProvider() (*Provider, error) {
  99. provider := Provider{Owner: "admin", Category: "Embedding"}
  100. existed, err := adapter.engine.Get(&provider)
  101. if err != nil {
  102. return &provider, err
  103. }
  104. if !existed {
  105. return nil, nil
  106. }
  107. return &provider, nil
  108. }
  109. func UpdateProvider(id string, provider *Provider) (bool, error) {
  110. owner, name := util.GetOwnerAndNameFromId(id)
  111. p, err := getProvider(owner, name)
  112. if err != nil {
  113. return false, err
  114. }
  115. if provider == nil {
  116. return false, nil
  117. }
  118. if provider.ClientSecret == "***" {
  119. provider.ClientSecret = p.ClientSecret
  120. }
  121. _, err = adapter.engine.ID(core.PK{owner, name}).AllCols().Update(provider)
  122. if err != nil {
  123. return false, err
  124. }
  125. // return affected != 0
  126. return true, nil
  127. }
  128. func AddProvider(provider *Provider) (bool, error) {
  129. affected, err := adapter.engine.Insert(provider)
  130. if err != nil {
  131. return false, err
  132. }
  133. return affected != 0, nil
  134. }
  135. func DeleteProvider(provider *Provider) (bool, error) {
  136. affected, err := adapter.engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{})
  137. if err != nil {
  138. return false, err
  139. }
  140. return affected != 0, nil
  141. }
  142. func (provider *Provider) GetId() string {
  143. return fmt.Sprintf("%s/%s", provider.Owner, provider.Name)
  144. }
  145. func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
  146. pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
  147. if err != nil {
  148. return nil, err
  149. }
  150. if pProvider == nil {
  151. return nil, fmt.Errorf("the model provider type: %s is not supported", p.Type)
  152. }
  153. return pProvider, nil
  154. }
  155. func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) {
  156. pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret)
  157. if err != nil {
  158. return nil, err
  159. }
  160. if pProvider == nil {
  161. return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type)
  162. }
  163. return pProvider, nil
  164. }