| @@ -0,0 +1,40 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package embedding | |||
| import ( | |||
| "context" | |||
| cohereembedder "github.com/henomis/lingoose/embedder/cohere" | |||
| ) | |||
| type CohereEmbeddingProvider struct { | |||
| subType string | |||
| secretKey string | |||
| } | |||
| func NewCohereEmbeddingProvider(subType string, secretKey string) (*CohereEmbeddingProvider, error) { | |||
| return &CohereEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||
| } | |||
| func (c *CohereEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
| client := cohereembedder.New().WithModel(cohereembedder.EmbedderModel(c.subType)).WithAPIKey(c.secretKey) | |||
| embed, err := client.Embed(ctx, []string{text}) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return float64ToFloat32(embed[0]), nil | |||
| } | |||
| @@ -0,0 +1,42 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package embedding | |||
| import ( | |||
| "context" | |||
| ernie "github.com/anhao/go-ernie" | |||
| ) | |||
| type ErnieEmbeddingProvider struct { | |||
| subType string | |||
| apiKey string | |||
| secretKey string | |||
| } | |||
| func NewErnieEmbeddingProvider(subType string, apiKey string, secretKey string) (*ErnieEmbeddingProvider, error) { | |||
| return &ErnieEmbeddingProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil | |||
| } | |||
| func (e *ErnieEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
| client := ernie.NewDefaultClient(e.apiKey, e.secretKey) | |||
| request := ernie.EmbeddingRequest{Input: []string{text}} | |||
| embeddings, err := client.CreateEmbeddings(ctx, request) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return float64ToFloat32(embeddings.Data[0].Embedding), nil | |||
| } | |||
| @@ -0,0 +1,49 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package embedding | |||
| import ( | |||
| "context" | |||
| "github.com/casbin/casibase/proxy" | |||
| "github.com/henomis/lingoose/embedder/huggingface" | |||
| ) | |||
| type HuggingFaceEmbeddingProvider struct { | |||
| subType string | |||
| secretKey string | |||
| } | |||
| func NewHuggingFaceEmbeddingProvider(subType string, secretKey string) (*HuggingFaceEmbeddingProvider, error) { | |||
| return &HuggingFaceEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||
| } | |||
| func (h *HuggingFaceEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
| client := huggingfaceembedder.New().WithToken(h.secretKey).WithModel(h.subType).WithHTTPClient(proxy.ProxyHttpClient) | |||
| embed, err := client.Embed(ctx, []string{text}) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return float64ToFloat32(embed[0]), nil | |||
| } | |||
| func float64ToFloat32(slice []float64) []float32 { | |||
| newSlice := make([]float32, len(slice)) | |||
| for i, v := range slice { | |||
| newSlice[i] = float32(v) | |||
| } | |||
| return newSlice | |||
| } | |||
| @@ -22,11 +22,17 @@ type EmbeddingProvider interface { | |||
| QueryVector(text string, ctx context.Context) ([]float32, error) | |||
| } | |||
| func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { | |||
| func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) { | |||
| var p EmbeddingProvider | |||
| var err error | |||
| if typ == "OpenAI" { | |||
| p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) | |||
| } else if typ == "Hugging Face" { | |||
| p, err = NewHuggingFaceEmbeddingProvider(subType, clientSecret) | |||
| } else if typ == "Cohere" { | |||
| p, err = NewCohereEmbeddingProvider(subType, clientSecret) | |||
| } else if typ == "Ernie" { | |||
| p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) | |||
| } | |||
| if err != nil { | |||
| @@ -44,6 +44,8 @@ require ( | |||
| github.com/gomodule/redigo v2.0.0+incompatible // indirect | |||
| github.com/google/go-cmp v0.5.9 // indirect | |||
| github.com/hashicorp/golang-lru v0.5.4 // indirect | |||
| github.com/henomis/cohere-go v1.0.1 // indirect | |||
| github.com/henomis/restclientgo v1.0.5 // indirect | |||
| github.com/jmespath/go-jmespath v0.4.0 // indirect | |||
| github.com/json-iterator/go v1.1.12 // indirect | |||
| github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect | |||
| @@ -278,8 +278,12 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO | |||
| github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= | |||
| github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= | |||
| github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= | |||
| github.com/henomis/cohere-go v1.0.1 h1:a47gIN29tqAl4yBTAT+BzQMjsWG94Fz07u9AE4Md+a8= | |||
| github.com/henomis/cohere-go v1.0.1/go.mod h1:F6D33jlWle6pbGdf9Fm2bteaOOQOO1cQtFnlFfj+ZXY= | |||
| github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw= | |||
| github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE= | |||
| github.com/henomis/restclientgo v1.0.5 h1:xMuznJLagE8nGrmFPyBkzsDztJm2A7uMLNGMBY5iWSg= | |||
| github.com/henomis/restclientgo v1.0.5/go.mod h1:xIeTCu2ZstvRn0fCukNpzXLN3m/kRTU0i0RwAbv7Zug= | |||
| github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | |||
| github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= | |||
| github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= | |||
| @@ -190,7 +190,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||
| } | |||
| func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | |||
| pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) | |||
| pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| @@ -638,6 +638,9 @@ export function getProviderTypeOptions(category) { | |||
| return ( | |||
| [ | |||
| {id: "OpenAI", name: "OpenAI"}, | |||
| {id: "Hugging Face", name: "Hugging Face"}, | |||
| {id: "Cohere", name: "Cohere"}, | |||
| {id: "Ernie", name: "Ernie"}, | |||
| ] | |||
| ); | |||
| } else { | |||
| @@ -701,16 +704,26 @@ export function getProviderSubTypeOptions(category, type) { | |||
| return []; | |||
| } | |||
| } else if (type === "Hugging Face") { | |||
| return ( | |||
| [ | |||
| {id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, | |||
| {id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, | |||
| {id: "bigscience/bloom", name: "bigscience/bloom"}, | |||
| {id: "gpt2", name: "gpt2"}, | |||
| {id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, | |||
| {id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, | |||
| ] | |||
| ); | |||
| if (category === "Model") { | |||
| return ( | |||
| [ | |||
| {id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, | |||
| {id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, | |||
| {id: "bigscience/bloom", name: "bigscience/bloom"}, | |||
| {id: "gpt2", name: "gpt2"}, | |||
| {id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, | |||
| {id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, | |||
| ] | |||
| ); | |||
| } else if (category === "Embedding") { | |||
| return ( | |||
| [ | |||
| {id: "sentence-transformers/all-MiniLM-L6-v2", name: "sentence-transformers/all-MiniLM-L6-v2"}, | |||
| ] | |||
| ); | |||
| } else { | |||
| return []; | |||
| } | |||
| } else if (type === "OpenRouter") { | |||
| return ( | |||
| [ | |||
| @@ -737,12 +750,30 @@ export function getProviderSubTypeOptions(category, type) { | |||
| ] | |||
| ); | |||
| } else if (type === "Ernie") { | |||
| if (category === "Model") { | |||
| return ( | |||
| [ | |||
| {id: "ERNIE-Bot", name: "ERNIE-Bot"}, | |||
| {id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, | |||
| {id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, | |||
| {id: "Llama-2", name: "Llama-2"}, | |||
| ] | |||
| ); | |||
| } else if (category === "Embedding") { | |||
| return ( | |||
| [ | |||
| {id: "default", name: "default"}, | |||
| ] | |||
| ); | |||
| } else { | |||
| return []; | |||
| } | |||
| } else if (type === "Cohere") { | |||
| return ( | |||
| [ | |||
| {id: "ERNIE-Bot", name: "ERNIE-Bot"}, | |||
| {id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, | |||
| {id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, | |||
| {id: "Llama-2", name: "Llama-2"}, | |||
| {id: "embed-english-v2.0", name: "embed-english-v2.0"}, | |||
| {id: "embed-english-light-v2.0", name: "embed-english-light-v2.0"}, | |||
| {id: "embed-multilingual-v2.0", name: "embed-multilingual-v2.0"}, | |||
| ] | |||
| ); | |||
| } else { | |||