@@ -0,0 +1,40 @@ | |||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||
// | |||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||
// you may not use this file except in compliance with the License. | |||
// You may obtain a copy of the License at | |||
// | |||
// http://www.apache.org/licenses/LICENSE-2.0 | |||
// | |||
// Unless required by applicable law or agreed to in writing, software | |||
// distributed under the License is distributed on an "AS IS" BASIS, | |||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package embedding | |||
import ( | |||
"context" | |||
cohereembedder "github.com/henomis/lingoose/embedder/cohere" | |||
) | |||
type CohereEmbeddingProvider struct { | |||
subType string | |||
secretKey string | |||
} | |||
func NewCohereEmbeddingProvider(subType string, secretKey string) (*CohereEmbeddingProvider, error) { | |||
return &CohereEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||
} | |||
func (c *CohereEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
client := cohereembedder.New().WithModel(cohereembedder.EmbedderModel(c.subType)).WithAPIKey(c.secretKey) | |||
embed, err := client.Embed(ctx, []string{text}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return float64ToFloat32(embed[0]), nil | |||
} |
@@ -0,0 +1,42 @@ | |||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||
// | |||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||
// you may not use this file except in compliance with the License. | |||
// You may obtain a copy of the License at | |||
// | |||
// http://www.apache.org/licenses/LICENSE-2.0 | |||
// | |||
// Unless required by applicable law or agreed to in writing, software | |||
// distributed under the License is distributed on an "AS IS" BASIS, | |||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package embedding | |||
import ( | |||
"context" | |||
ernie "github.com/anhao/go-ernie" | |||
) | |||
type ErnieEmbeddingProvider struct { | |||
subType string | |||
apiKey string | |||
secretKey string | |||
} | |||
func NewErnieEmbeddingProvider(subType string, apiKey string, secretKey string) (*ErnieEmbeddingProvider, error) { | |||
return &ErnieEmbeddingProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil | |||
} | |||
func (e *ErnieEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
client := ernie.NewDefaultClient(e.apiKey, e.secretKey) | |||
request := ernie.EmbeddingRequest{Input: []string{text}} | |||
embeddings, err := client.CreateEmbeddings(ctx, request) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return float64ToFloat32(embeddings.Data[0].Embedding), nil | |||
} |
@@ -0,0 +1,49 @@ | |||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||
// | |||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||
// you may not use this file except in compliance with the License. | |||
// You may obtain a copy of the License at | |||
// | |||
// http://www.apache.org/licenses/LICENSE-2.0 | |||
// | |||
// Unless required by applicable law or agreed to in writing, software | |||
// distributed under the License is distributed on an "AS IS" BASIS, | |||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package embedding | |||
import ( | |||
"context" | |||
"github.com/casbin/casibase/proxy" | |||
"github.com/henomis/lingoose/embedder/huggingface" | |||
) | |||
type HuggingFaceEmbeddingProvider struct { | |||
subType string | |||
secretKey string | |||
} | |||
func NewHuggingFaceEmbeddingProvider(subType string, secretKey string) (*HuggingFaceEmbeddingProvider, error) { | |||
return &HuggingFaceEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||
} | |||
func (h *HuggingFaceEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) { | |||
client := huggingfaceembedder.New().WithToken(h.secretKey).WithModel(h.subType).WithHTTPClient(proxy.ProxyHttpClient) | |||
embed, err := client.Embed(ctx, []string{text}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return float64ToFloat32(embed[0]), nil | |||
} | |||
func float64ToFloat32(slice []float64) []float32 { | |||
newSlice := make([]float32, len(slice)) | |||
for i, v := range slice { | |||
newSlice[i] = float32(v) | |||
} | |||
return newSlice | |||
} |
@@ -22,11 +22,17 @@ type EmbeddingProvider interface { | |||
QueryVector(text string, ctx context.Context) ([]float32, error) | |||
} | |||
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { | |||
func GetEmbeddingProvider(typ string, subType string, clientId string, clientSecret string) (EmbeddingProvider, error) { | |||
var p EmbeddingProvider | |||
var err error | |||
if typ == "OpenAI" { | |||
p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) | |||
} else if typ == "Hugging Face" { | |||
p, err = NewHuggingFaceEmbeddingProvider(subType, clientSecret) | |||
} else if typ == "Cohere" { | |||
p, err = NewCohereEmbeddingProvider(subType, clientSecret) | |||
} else if typ == "Ernie" { | |||
p, err = NewErnieEmbeddingProvider(subType, clientId, clientSecret) | |||
} | |||
if err != nil { | |||
@@ -44,6 +44,8 @@ require ( | |||
github.com/gomodule/redigo v2.0.0+incompatible // indirect | |||
github.com/google/go-cmp v0.5.9 // indirect | |||
github.com/hashicorp/golang-lru v0.5.4 // indirect | |||
github.com/henomis/cohere-go v1.0.1 // indirect | |||
github.com/henomis/restclientgo v1.0.5 // indirect | |||
github.com/jmespath/go-jmespath v0.4.0 // indirect | |||
github.com/json-iterator/go v1.1.12 // indirect | |||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect | |||
@@ -278,8 +278,12 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO | |||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= | |||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= | |||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= | |||
github.com/henomis/cohere-go v1.0.1 h1:a47gIN29tqAl4yBTAT+BzQMjsWG94Fz07u9AE4Md+a8= | |||
github.com/henomis/cohere-go v1.0.1/go.mod h1:F6D33jlWle6pbGdf9Fm2bteaOOQOO1cQtFnlFfj+ZXY= | |||
github.com/henomis/lingoose v0.0.11-alpha1 h1:6iXcdewIdTDJCNg7AxZF6onobLEh0BPFyHYTKSV8bAw= | |||
github.com/henomis/lingoose v0.0.11-alpha1/go.mod h1:hOfRJswe3sA17uZSUJHJNrBiqPxEt2FM9wUFqFFOHSE= | |||
github.com/henomis/restclientgo v1.0.5 h1:xMuznJLagE8nGrmFPyBkzsDztJm2A7uMLNGMBY5iWSg= | |||
github.com/henomis/restclientgo v1.0.5/go.mod h1:xIeTCu2ZstvRn0fCukNpzXLN3m/kRTU0i0RwAbv7Zug= | |||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | |||
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= | |||
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= | |||
@@ -190,7 +190,7 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||
} | |||
func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | |||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) | |||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -638,6 +638,9 @@ export function getProviderTypeOptions(category) { | |||
return ( | |||
[ | |||
{id: "OpenAI", name: "OpenAI"}, | |||
{id: "Hugging Face", name: "Hugging Face"}, | |||
{id: "Cohere", name: "Cohere"}, | |||
{id: "Ernie", name: "Ernie"}, | |||
] | |||
); | |||
} else { | |||
@@ -701,16 +704,26 @@ export function getProviderSubTypeOptions(category, type) { | |||
return []; | |||
} | |||
} else if (type === "Hugging Face") { | |||
return ( | |||
[ | |||
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, | |||
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, | |||
{id: "bigscience/bloom", name: "bigscience/bloom"}, | |||
{id: "gpt2", name: "gpt2"}, | |||
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, | |||
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, | |||
] | |||
); | |||
if (category === "Model") { | |||
return ( | |||
[ | |||
{id: "meta-llama/Llama-2-7b", name: "meta-llama/Llama-2-7b"}, | |||
{id: "tiiuae/falcon-180B", name: "tiiuae/falcon-180B"}, | |||
{id: "bigscience/bloom", name: "bigscience/bloom"}, | |||
{id: "gpt2", name: "gpt2"}, | |||
{id: "baichuan-inc/Baichuan2-13B-Chat", name: "baichuan-inc/Baichuan2-13B-Chat"}, | |||
{id: "THUDM/chatglm2-6b", name: "THUDM/chatglm2-6b"}, | |||
] | |||
); | |||
} else if (category === "Embedding") { | |||
return ( | |||
[ | |||
{id: "sentence-transformers/all-MiniLM-L6-v2", name: "sentence-transformers/all-MiniLM-L6-v2"}, | |||
] | |||
); | |||
} else { | |||
return []; | |||
} | |||
} else if (type === "OpenRouter") { | |||
return ( | |||
[ | |||
@@ -737,12 +750,30 @@ export function getProviderSubTypeOptions(category, type) { | |||
] | |||
); | |||
} else if (type === "Ernie") { | |||
if (category === "Model") { | |||
return ( | |||
[ | |||
{id: "ERNIE-Bot", name: "ERNIE-Bot"}, | |||
{id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, | |||
{id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, | |||
{id: "Llama-2", name: "Llama-2"}, | |||
] | |||
); | |||
} else if (category === "Embedding") { | |||
return ( | |||
[ | |||
{id: "default", name: "default"}, | |||
] | |||
); | |||
} else { | |||
return []; | |||
} | |||
} else if (type === "Cohere") { | |||
return ( | |||
[ | |||
{id: "ERNIE-Bot", name: "ERNIE-Bot"}, | |||
{id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"}, | |||
{id: "BLOOMZ-7B", name: "BLOOMZ-7B"}, | |||
{id: "Llama-2", name: "Llama-2"}, | |||
{id: "embed-english-v2.0", name: "embed-english-v2.0"}, | |||
{id: "embed-english-light-v2.0", name: "embed-english-light-v2.0"}, | |||
{id: "embed-multilingual-v2.0", name: "embed-multilingual-v2.0"}, | |||
] | |||
); | |||
} else { | |||