You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

provider.go 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import (
  16. "fmt"
  17. "github.com/casbin/casibase/embedding"
  18. "github.com/casbin/casibase/model"
  19. "github.com/casbin/casibase/storage"
  20. "github.com/casbin/casibase/util"
  21. "xorm.io/core"
  22. )
  23. type Provider struct {
  24. Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
  25. Name string `xorm:"varchar(100) notnull pk" json:"name"`
  26. CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
  27. DisplayName string `xorm:"varchar(100)" json:"displayName"`
  28. Category string `xorm:"varchar(100)" json:"category"`
  29. Type string `xorm:"varchar(100)" json:"type"`
  30. SubType string `xorm:"varchar(100)" json:"subType"`
  31. ClientId string `xorm:"varchar(100)" json:"clientId"`
  32. ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"`
  33. ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
  34. ApiVersion string `xorm:"varchar(100)" json:"apiVersion"`
  35. Temperature float32 `xorm:"float" json:"temperature"`
  36. TopP float32 `xorm:"float" json:"topP"`
  37. TopK int `xorm:"int" json:"topK"`
  38. FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"`
  39. PresencePenalty float32 `xorm:"float" json:"presencePenalty"`
  40. }
  41. func GetMaskedProvider(provider *Provider, isMaskEnabled bool) *Provider {
  42. if !isMaskEnabled {
  43. return provider
  44. }
  45. if provider == nil {
  46. return nil
  47. }
  48. if provider.ClientSecret != "" {
  49. provider.ClientSecret = "***"
  50. }
  51. return provider
  52. }
  53. func GetMaskedProviders(providers []*Provider, isMaskEnabled bool) []*Provider {
  54. if !isMaskEnabled {
  55. return providers
  56. }
  57. for _, provider := range providers {
  58. provider = GetMaskedProvider(provider, isMaskEnabled)
  59. }
  60. return providers
  61. }
  62. func GetGlobalProviders() ([]*Provider, error) {
  63. providers := []*Provider{}
  64. err := adapter.engine.Asc("owner").Desc("created_time").Find(&providers)
  65. if err != nil {
  66. return providers, err
  67. }
  68. return providers, nil
  69. }
  70. func GetProviders(owner string) ([]*Provider, error) {
  71. providers := []*Provider{}
  72. err := adapter.engine.Desc("created_time").Find(&providers, &Provider{Owner: owner})
  73. if err != nil {
  74. return providers, err
  75. }
  76. return providers, nil
  77. }
  78. func getProvider(owner string, name string) (*Provider, error) {
  79. provider := Provider{Owner: owner, Name: name}
  80. existed, err := adapter.engine.Get(&provider)
  81. if err != nil {
  82. return &provider, err
  83. }
  84. if existed {
  85. return &provider, nil
  86. } else {
  87. return nil, nil
  88. }
  89. }
  90. func GetProvider(id string) (*Provider, error) {
  91. owner, name := util.GetOwnerAndNameFromId(id)
  92. return getProvider(owner, name)
  93. }
  94. func GetDefaultStorageProvider() (*Provider, error) {
  95. provider := Provider{Owner: "admin", Category: "Storage"}
  96. existed, err := adapter.engine.Get(&provider)
  97. if err != nil {
  98. return &provider, err
  99. }
  100. if !existed {
  101. return nil, nil
  102. }
  103. return &provider, nil
  104. }
  105. func GetDefaultModelProvider() (*Provider, error) {
  106. provider := Provider{Owner: "admin", Category: "Model"}
  107. existed, err := adapter.engine.Get(&provider)
  108. if err != nil {
  109. return &provider, err
  110. }
  111. if !existed {
  112. return nil, nil
  113. }
  114. return &provider, nil
  115. }
  116. func GetDefaultEmbeddingProvider() (*Provider, error) {
  117. provider := Provider{Owner: "admin", Category: "Embedding"}
  118. existed, err := adapter.engine.Get(&provider)
  119. if err != nil {
  120. return &provider, err
  121. }
  122. if !existed {
  123. return nil, nil
  124. }
  125. return &provider, nil
  126. }
  127. func UpdateProvider(id string, provider *Provider) (bool, error) {
  128. owner, name := util.GetOwnerAndNameFromId(id)
  129. p, err := getProvider(owner, name)
  130. if err != nil {
  131. return false, err
  132. }
  133. if provider == nil {
  134. return false, nil
  135. }
  136. if provider.ClientSecret == "***" {
  137. provider.ClientSecret = p.ClientSecret
  138. }
  139. _, err = adapter.engine.ID(core.PK{owner, name}).AllCols().Update(provider)
  140. if err != nil {
  141. return false, err
  142. }
  143. // return affected != 0
  144. return true, nil
  145. }
  146. func AddProvider(provider *Provider) (bool, error) {
  147. affected, err := adapter.engine.Insert(provider)
  148. if err != nil {
  149. return false, err
  150. }
  151. return affected != 0, nil
  152. }
  153. func DeleteProvider(provider *Provider) (bool, error) {
  154. affected, err := adapter.engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{})
  155. if err != nil {
  156. return false, err
  157. }
  158. return affected != 0, nil
  159. }
  160. func (provider *Provider) GetId() string {
  161. return fmt.Sprintf("%s/%s", provider.Owner, provider.Name)
  162. }
  163. func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) {
  164. pProvider, err := storage.GetStorageProvider(p.Type, p.ClientId, p.Name)
  165. if err != nil {
  166. return nil, err
  167. }
  168. if pProvider == nil {
  169. return nil, fmt.Errorf("the storage provider type: %s is not supported", p.Type)
  170. }
  171. return pProvider, nil
  172. }
  173. func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
  174. pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl, p.ApiVersion)
  175. if err != nil {
  176. return nil, err
  177. }
  178. if pProvider == nil {
  179. return nil, fmt.Errorf("the model provider type: %s is not supported", p.Type)
  180. }
  181. return pProvider, nil
  182. }
  183. func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) {
  184. pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.ProviderUrl)
  185. if err != nil {
  186. return nil, err
  187. }
  188. if pProvider == nil {
  189. return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type)
  190. }
  191. return pProvider, nil
  192. }