diff --git a/embedding/cohere.go b/embedding/cohere.go new file mode 100644 index 0000000..1d96035 --- /dev/null +++ b/embedding/cohere.go @@ -0,0 +1,40 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package embedding + +import ( + "context" + + cohereembedder "github.com/henomis/lingoose/embedder/cohere" +) + +type CohereEmbeddingProvider struct { + subType string + secretKey string +} + +func NewCohereEmbeddingProvider(subType string, secretKey string) (*CohereEmbeddingProvider, error) { + return &CohereEmbeddingProvider{subType: subType, secretKey: secretKey}, nil +} + +func (c *CohereEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { + client := cohereembedder.New().WithModel(cohereembedder.EmbedderModel(c.subType)).WithAPIKey(c.secretKey) + embed, err := client.Embed(ctx, []string{text}) + if err != nil { + return nil, err + } + + return float64ToFloat32(embed[0]), nil +} diff --git a/embedding/ernie.go b/embedding/ernie.go new file mode 100644 index 0000000..3713c8b --- /dev/null +++ b/embedding/ernie.go @@ -0,0 +1,42 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package embedding + +import ( + "context" + + ernie "github.com/anhao/go-ernie" +) + +type ErnieEmbeddingProvider struct { + subType string + apiKey string + secretKey string +} + +func NewErnieEmbeddingProvider(subType string, apiKey string, secretKey string) (*ErnieEmbeddingProvider, error) { + return &ErnieEmbeddingProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil +} + +func (e *ErnieEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { + client := ernie.NewDefaultClient(e.apiKey, e.secretKey) + request := ernie.EmbeddingRequest{Input: []string{text}} + embeddings, err := client.CreateEmbeddings(ctx, request) + if err != nil { + return nil, err + } + + return float64ToFloat32(embeddings.Data[0].Embedding), nil +} diff --git a/embedding/huggingface.go b/embedding/huggingface.go new file mode 100644 index 0000000..09c9b39 --- /dev/null +++ b/embedding/huggingface.go @@ -0,0 +1,49 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package embedding + +import ( + "context" + + "github.com/casbin/casibase/proxy" + "github.com/henomis/lingoose/embedder/huggingface" +) + +type HuggingFaceEmbeddingProvider struct { + subType string + secretKey string +} + +func NewHuggingFaceEmbeddingProvider(subType string, secretKey string) (*HuggingFaceEmbeddingProvider, error) { + return &HuggingFaceEmbeddingProvider{subType: subType, secretKey: secretKey}, nil +} + +func (h *HuggingFaceEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { + client := huggingfaceembedder.New().WithToken(h.secretKey).WithModel(h.subType).WithHTTPClient(proxy.ProxyHttpClient) + embed, err := client.Embed(ctx, []string{text}) + if err != nil { + return nil, err + } + + return float64ToFloat32(embed[0]), nil +} + +func float64ToFloat32(slice []float64) []float32 { + newSlice := make([]float32, len(slice)) + for i, v := range slice { + newSlice[i] = float32(v) + } + return newSlice +} diff --git a/embedding/provider.go b/embedding/provider.go index a1a73e0..488fbb5 100644 --- a/embedding/provider.go +++ b/embedding/provider.go @@ -22,11 +22,17 @@ type EmbeddingProvider interface { QueryVector(text string, ctx context.Context) ([]float32, error) } -func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { +func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) { var p EmbeddingProvider var err error if typ == "OpenAI" { p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) + } else if typ == "Hugging Face" { + p, err = NewHuggingFaceEmbeddingProvider(subType, clientSecret) + } else if typ == "Cohere" { + p, err = NewCohereEmbeddingProvider(subType, clientSecret) + } else if typ == "Ernie" { + p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) } if err != nil { diff --git a/go.mod b/go.mod index 6bb95ff..6bb96a5 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,8 @@ require ( github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect + github.com/henomis/cohere-go v1.0.1 // indirect + github.com/henomis/restclientgo v1.0.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect diff --git a/go.sum b/go.sum index fd30677..b0c225c 100644 --- a/go.sum +++ b/go.sum @@ -278,8 +278,12 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/henomis/cohere-go v1.0.1 h1:a47gIN29tqAl4yBTAT+BzQMjsWG94Fz07u9AE4Md+a8= +github.com/henomis/cohere-go v1.0.1/go.mod h1:F6D33jlWle6pbGdf9Fm2bteaOOQOO1cQtFnlFfj+ZXY= github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw= github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE= +github.com/henomis/restclientgo v1.0.5 h1:xMuznJLagE8nGrmFPyBkzsDztJm2A7uMLNGMBY5iWSg= +github.com/henomis/restclientgo v1.0.5/go.mod h1:xIeTCu2ZstvRn0fCukNpzXLN3m/kRTU0i0RwAbv7Zug= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/object/provider.go b/object/provider.go index 9459f23..61ab19b 100644 --- a/object/provider.go +++ b/object/provider.go @@ -190,7 +190,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { } func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { - pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) + pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) if err != nil { return nil, err } diff --git a/web/src/Setting.js b/web/src/Setting.js index d91ac9a..94e9421 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -638,6 +638,9 @@ export function getProviderTypeOptions(category) { return ( [ {id: "OpenAI", name: "OpenAI"}, + {id: "Hugging Face", name: "Hugging Face"}, + {id: "Cohere", name: "Cohere"}, + {id: "Ernie", name: "Ernie"}, ] ); } else { @@ -701,16 +704,26 @@ export function getProviderSubTypeOptions(category, type) { return []; } } else if (type === "Hugging Face") { - return ( - [ - {id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, - {id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, - {id: "bigscience/bloom", name: "bigscience/bloom"}, - {id: "gpt2", name: "gpt2"}, - {id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, - {id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, - ] - ); + if (category === "Model") { + return ( + [ + {id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, + {id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, + {id: "bigscience/bloom", name: "bigscience/bloom"}, + {id: "gpt2", name: "gpt2"}, + {id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, + {id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, + ] + ); + } else if (category === "Embedding") { + return ( + [ + {id: "sentence-transformers/all-MiniLM-L6-v2", name: "sentence-transformers/all-MiniLM-L6-v2"}, + ] + ); + } else { + return []; + } } else if (type === "OpenRouter") { return ( [ @@ -737,12 +750,30 @@ export function getProviderSubTypeOptions(category, type) { ] ); } else if (type === "Ernie") { + if (category === "Model") { + return ( + [ + {id: "ERNIE-Bot", name: "ERNIE-Bot"}, + {id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, + {id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, + {id: "Llama-2", name: "Llama-2"}, + ] + ); + } else if (category === "Embedding") { + return ( + [ + {id: "default", name: "default"}, + ] + ); + } else { + return []; + } + } else if (type === "Cohere") { return ( [ - {id: "ERNIE-Bot", name: "ERNIE-Bot"}, - {id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, - {id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, - {id: "Llama-2", name: "Llama-2"}, + {id: "embed-english-v2.0", name: "embed-english-v2.0"}, + {id: "embed-english-light-v2.0", name: "embed-english-light-v2.0"}, + {id: "embed-multilingual-v2.0", name: "embed-multilingual-v2.0"}, ] ); } else {