| @@ -0,0 +1,58 @@ | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| package embedding | |||||
| import ( | |||||
| "context" | |||||
| "github.com/sashabaranov/go-openai" | |||||
| ) | |||||
| type LocalEmbeddingProvider struct { | |||||
| subType string | |||||
| secretKey string | |||||
| providerUrl string | |||||
| } | |||||
| func NewLocalEmbeddingProvider(subType string, secretKey string, providerUrl string) (*LocalEmbeddingProvider, error) { | |||||
| p := &LocalEmbeddingProvider{ | |||||
| subType: subType, | |||||
| secretKey: secretKey, | |||||
| providerUrl: providerUrl, | |||||
| } | |||||
| return p, nil | |||||
| } | |||||
| func getLocalClientFromUrl(authToken string, url string) *openai.Client { | |||||
| config := openai.DefaultConfig(authToken) | |||||
| config.BaseURL = url | |||||
| c := openai.NewClientWithConfig(config) | |||||
| return c | |||||
| } | |||||
| func (p *LocalEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||||
| client := getLocalClientFromUrl(p.secretKey, p.providerUrl) | |||||
| resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||||
| Input: []string{text}, | |||||
| Model: openai.EmbeddingModel(1), | |||||
| }) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return resp.Data[0].Embedding, nil | |||||
| } | |||||
| @@ -20,7 +20,7 @@ type EmbeddingProvider interface { | |||||
| QueryVector(text string, ctx context.Context) ([]float32, error) | QueryVector(text string, ctx context.Context) ([]float32, error) | ||||
| } | } | ||||
| func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) { | |||||
| func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string, providerUrl string) (EmbeddingProvider, error) { | |||||
| var p EmbeddingProvider | var p EmbeddingProvider | ||||
| var err error | var err error | ||||
| if typ == "OpenAI" { | if typ == "OpenAI" { | ||||
| @@ -31,6 +31,8 @@ func GetEmbeddingProvider(typ string, subType string, clientId string, clientSec | |||||
| p, err = NewCohereEmbeddingProvider(subType, clientSecret) | p, err = NewCohereEmbeddingProvider(subType, clientSecret) | ||||
| } else if typ == "Ernie" { | } else if typ == "Ernie" { | ||||
| p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) | p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) | ||||
| } else if typ == "Local" { | |||||
| p, err = NewLocalEmbeddingProvider(subType, clientSecret, providerUrl) | |||||
| } | } | ||||
| if err != nil { | if err != nil { | ||||
| @@ -224,7 +224,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||||
| } | } | ||||
| func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | ||||
| pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||||
| pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.ProviderUrl) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| @@ -196,6 +196,8 @@ class ProviderEditPage extends React.Component { | |||||
| this.updateProviderField("subType", "embed-english-v2.0"); | this.updateProviderField("subType", "embed-english-v2.0"); | ||||
| } else if (value === "Ernie") { | } else if (value === "Ernie") { | ||||
| this.updateProviderField("subType", "default"); | this.updateProviderField("subType", "default"); | ||||
| } else if (value === "Local") { | |||||
| this.updateProviderField("subType", "custom-embedding"); | |||||
| } | } | ||||
| } | } | ||||
| })}> | })}> | ||||
| @@ -657,6 +657,7 @@ export function getProviderTypeOptions(category) { | |||||
| {id: "Hugging Face", name: "Hugging Face"}, | {id: "Hugging Face", name: "Hugging Face"}, | ||||
| {id: "Cohere", name: "Cohere"}, | {id: "Cohere", name: "Cohere"}, | ||||
| {id: "Ernie", name: "Ernie"}, | {id: "Ernie", name: "Ernie"}, | ||||
| {id: "Local", name: "Local"}, | |||||
| ] | ] | ||||
| ); | ); | ||||
| } else { | } else { | ||||