Browse Source

feat: support Local embedding using openai format (#664)

master
Kelvin Chiu GitHub 2 years ago
parent
commit
dad89cd383
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 2 deletions
  1. +58
    -0
      embedding/local.go
  2. +3
    -1
      embedding/provider.go
  3. +1
    -1
      object/provider.go
  4. +2
    -0
      web/src/ProviderEditPage.js
  5. +1
    -0
      web/src/Setting.js

+ 58
- 0
embedding/local.go View File

@@ -0,0 +1,58 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package embedding

import (
"context"

"github.com/sashabaranov/go-openai"
)

type LocalEmbeddingProvider struct {
subType string
secretKey string
providerUrl string
}

func NewLocalEmbeddingProvider(subType string, secretKey string, providerUrl string) (*LocalEmbeddingProvider, error) {
p := &LocalEmbeddingProvider{
subType: subType,
secretKey: secretKey,
providerUrl: providerUrl,
}
return p, nil
}

func getLocalClientFromUrl(authToken string, url string) *openai.Client {
config := openai.DefaultConfig(authToken)
config.BaseURL = url

c := openai.NewClientWithConfig(config)
return c
}

func (p *LocalEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
client := getLocalClientFromUrl(p.secretKey, p.providerUrl)

resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Input: []string{text},
Model: openai.EmbeddingModel(1),
})
if err != nil {
return nil, err
}

return resp.Data[0].Embedding, nil
}

+ 3
- 1
embedding/provider.go View File

@@ -20,7 +20,7 @@ type EmbeddingProvider interface {
QueryVector(text string, ctx context.Context) ([]float32, error)
}

func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) {
func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string, providerUrl string) (EmbeddingProvider, error) {
var p EmbeddingProvider
var err error
if typ == "OpenAI" {
@@ -31,6 +31,8 @@ func GetEmbeddingProvider(typ string, subType string, clientId string, clientSec
p, err = NewCohereEmbeddingProvider(subType, clientSecret)
} else if typ == "Ernie" {
p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret)
} else if typ == "Local" {
p, err = NewLocalEmbeddingProvider(subType, clientSecret, providerUrl)
}

if err != nil {


+ 1
- 1
object/provider.go View File

@@ -224,7 +224,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
}

func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) {
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.ProviderUrl)
if err != nil {
return nil, err
}


+ 2
- 0
web/src/ProviderEditPage.js View File

@@ -196,6 +196,8 @@ class ProviderEditPage extends React.Component {
this.updateProviderField("subType", "embed-english-v2.0");
} else if (value === "Ernie") {
this.updateProviderField("subType", "default");
} else if (value === "Local") {
this.updateProviderField("subType", "custom-embedding");
}
}
})}>


+ 1
- 0
web/src/Setting.js View File

@@ -657,6 +657,7 @@ export function getProviderTypeOptions(category) {
{id: "Hugging Face", name: "Hugging Face"},
{id: "Cohere", name: "Cohere"},
{id: "Ernie", name: "Ernie"},
{id: "Local", name: "Local"},
]
);
} else {


Loading…
Cancel
Save