From dad89cd3837bbb200744c680e58abd1de8d7d007 Mon Sep 17 00:00:00 2001 From: Kelvin Chiu Date: Fri, 29 Sep 2023 21:55:19 +0800 Subject: [PATCH] feat: support Local embedding using openai format (#664) --- embedding/local.go | 58 +++++++++++++++++++++++++++++++++++++ embedding/provider.go | 4 ++- object/provider.go | 2 +- web/src/ProviderEditPage.js | 2 ++ web/src/Setting.js | 1 + 5 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 embedding/local.go diff --git a/embedding/local.go b/embedding/local.go new file mode 100644 index 0000000..d1d9b42 --- /dev/null +++ b/embedding/local.go @@ -0,0 +1,58 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package embedding + +import ( + "context" + + "github.com/sashabaranov/go-openai" +) + +type LocalEmbeddingProvider struct { + subType string + secretKey string + providerUrl string +} + +func NewLocalEmbeddingProvider(subType string, secretKey string, providerUrl string) (*LocalEmbeddingProvider, error) { + p := &LocalEmbeddingProvider{ + subType: subType, + secretKey: secretKey, + providerUrl: providerUrl, + } + return p, nil +} + +func getLocalClientFromUrl(authToken string, url string) *openai.Client { + config := openai.DefaultConfig(authToken) + config.BaseURL = url + + c := openai.NewClientWithConfig(config) + return c +} + +func (p *LocalEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { + client := getLocalClientFromUrl(p.secretKey, p.providerUrl) + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: []string{text}, + Model: openai.EmbeddingModel(1), + }) + if err != nil { + return nil, err + } + + return resp.Data[0].Embedding, nil +} diff --git a/embedding/provider.go b/embedding/provider.go index 367316e..d45e586 100644 --- a/embedding/provider.go +++ b/embedding/provider.go @@ -20,7 +20,7 @@ type EmbeddingProvider interface { QueryVector(text string, ctx context.Context) ([]float32, error) } -func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) { +func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string, providerUrl string) (EmbeddingProvider, error) { var p EmbeddingProvider var err error if typ == "OpenAI" { @@ -31,6 +31,8 @@ func GetEmbeddingProvider(typ string, subType string, clientId string, clientSec p, err = NewCohereEmbeddingProvider(subType, clientSecret) } else if typ == "Ernie" { p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) + } else if typ == "Local" { + p, err = NewLocalEmbeddingProvider(subType, clientSecret, providerUrl) } if err != nil { diff --git a/object/provider.go b/object/provider.go index 2b0aa9b..29958f1 100644 --- a/object/provider.go +++ b/object/provider.go @@ -224,7 +224,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { } func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { - pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) + pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.ProviderUrl) if err != nil { return nil, err } diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 3c0ba24..4d08aa9 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -196,6 +196,8 @@ class ProviderEditPage extends React.Component { this.updateProviderField("subType", "embed-english-v2.0"); } else if (value === "Ernie") { this.updateProviderField("subType", "default"); + } else if (value === "Local") { + this.updateProviderField("subType", "custom-embedding"); } } })}> diff --git a/web/src/Setting.js b/web/src/Setting.js index 12b0cbc..326cd69 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -657,6 +657,7 @@ export function getProviderTypeOptions(category) { {id: "Hugging Face", name: "Hugging Face"}, {id: "Cohere", name: "Cohere"}, {id: "Ernie", name: "Ernie"}, + {id: "Local", name: "Local"}, ] ); } else {