@@ -68,46 +68,6 @@ func (c *ApiController) GetMessage() { | |||
c.ResponseOk(message) | |||
} | |||
func (c *ApiController) ResponseErrorStream(errorText string) { | |||
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) | |||
_, err := c.Ctx.ResponseWriter.Write([]byte(event)) | |||
if err != nil { | |||
c.ResponseError(err.Error()) | |||
return | |||
} | |||
} | |||
func getModelProviderFromContext(owner string, name string) (*object.Provider, error) { | |||
var providerName string | |||
if name != "" { | |||
providerName = name | |||
} else { | |||
store, err := object.GetDefaultStore(owner) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if store != nil && store.ModelProvider != "" { | |||
providerName = store.ModelProvider | |||
} | |||
} | |||
var provider *object.Provider | |||
var err error | |||
if providerName != "" { | |||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||
provider, err = object.GetProvider(providerId) | |||
} else { | |||
provider, err = object.GetDefaultModelProvider() | |||
} | |||
if provider == nil && err == nil { | |||
return nil, fmt.Errorf("The provider: %s is not found", providerName) | |||
} else { | |||
return provider, err | |||
} | |||
} | |||
func (c *ApiController) GetMessageAnswer() { | |||
id := c.Input().Get("id") | |||
@@ -154,13 +114,15 @@ func (c *ApiController) GetMessageAnswer() { | |||
return | |||
} | |||
provider, err := getModelProviderFromContext(chat.Owner, chat.User2) | |||
modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2) | |||
if err != nil { | |||
c.ResponseError(err.Error()) | |||
return | |||
} | |||
if provider.Category != "Model" || provider.ClientSecret == "" { | |||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId())) | |||
embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2) | |||
if err != nil { | |||
c.ResponseError(err.Error()) | |||
return | |||
} | |||
@@ -168,11 +130,10 @@ func (c *ApiController) GetMessageAnswer() { | |||
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") | |||
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") | |||
authToken := provider.ClientSecret | |||
question := questionMessage.Text | |||
var stringBuilder strings.Builder | |||
nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question) | |||
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) | |||
if err != nil && err.Error() != "no knowledge vectors found" { | |||
c.ResponseErrorStream(err.Error()) | |||
return | |||
@@ -184,13 +145,7 @@ func (c *ApiController) GetMessageAnswer() { | |||
fmt.Printf("Context: [%s]\n", nearestText) | |||
fmt.Printf("Answer: [") | |||
modelProvider, err := provider.GetModelProvider() | |||
if err != nil { | |||
c.ResponseErrorStream(err.Error()) | |||
return | |||
} | |||
err = modelProvider.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) | |||
err = modelProviderObj.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) | |||
if err != nil { | |||
c.ResponseErrorStream(err.Error()) | |||
return | |||
@@ -0,0 +1,111 @@ | |||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||
// | |||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||
// you may not use this file except in compliance with the License. | |||
// You may obtain a copy of the License at | |||
// | |||
// http://www.apache.org/licenses/LICENSE-2.0 | |||
// | |||
// Unless required by applicable law or agreed to in writing, software | |||
// distributed under the License is distributed on an "AS IS" BASIS, | |||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package controllers | |||
import ( | |||
"fmt" | |||
"github.com/casbin/casibase/embedding" | |||
"github.com/casbin/casibase/model" | |||
"github.com/casbin/casibase/object" | |||
"github.com/casbin/casibase/util" | |||
) | |||
func (c *ApiController) ResponseErrorStream(errorText string) { | |||
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) | |||
_, err := c.Ctx.ResponseWriter.Write([]byte(event)) | |||
if err != nil { | |||
c.ResponseError(err.Error()) | |||
return | |||
} | |||
} | |||
func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) { | |||
var providerName string | |||
if name != "" { | |||
providerName = name | |||
} else { | |||
store, err := object.GetDefaultStore(owner) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if store != nil && store.ModelProvider != "" { | |||
providerName = store.ModelProvider | |||
} | |||
} | |||
var provider *object.Provider | |||
var err error | |||
if providerName != "" { | |||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||
provider, err = object.GetProvider(providerId) | |||
} else { | |||
provider, err = object.GetDefaultModelProvider() | |||
} | |||
if provider == nil && err == nil { | |||
return nil, fmt.Errorf("The model provider: %s is not found", providerName) | |||
} | |||
if provider.Category != "Model" || provider.ClientSecret == "" { | |||
return nil, fmt.Errorf("The model provider: %s is invalid", providerName) | |||
} | |||
providerObj, err := provider.GetModelProvider() | |||
if err != nil { | |||
return nil, err | |||
} | |||
return providerObj, err | |||
} | |||
func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) { | |||
var providerName string | |||
if name != "" { | |||
providerName = name | |||
} else { | |||
store, err := object.GetDefaultStore(owner) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if store != nil && store.EmbeddingProvider != "" { | |||
providerName = store.EmbeddingProvider | |||
} | |||
} | |||
var provider *object.Provider | |||
var err error | |||
if providerName != "" { | |||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||
provider, err = object.GetProvider(providerId) | |||
} else { | |||
provider, err = object.GetDefaultEmbeddingProvider() | |||
} | |||
if provider == nil && err == nil { | |||
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName) | |||
} | |||
if provider.Category != "Embedding" || provider.ClientSecret == "" { | |||
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) | |||
} | |||
providerObj, err := provider.GetEmbeddingProvider() | |||
if err != nil { | |||
return nil, err | |||
} | |||
return providerObj, err | |||
} |
@@ -12,13 +12,26 @@ | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package model | |||
package embedding | |||
import ( | |||
"context" | |||
"time" | |||
"github.com/casbin/casibase/proxy" | |||
"github.com/casbin/casibase/util" | |||
"github.com/sashabaranov/go-openai" | |||
) | |||
type OpenAiEmbeddingProvider struct { | |||
subType string | |||
secretKey string | |||
} | |||
func NewOpenAiEmbeddingProvider(subType string, secretKey string) (*OpenAiEmbeddingProvider, error) { | |||
return &OpenAiEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||
} | |||
func getProxyClientFromToken(authToken string) *openai.Client { | |||
config := openai.DefaultConfig(authToken) | |||
config.HTTPClient = proxy.ProxyHttpClient | |||
@@ -26,3 +39,20 @@ func getProxyClientFromToken(authToken string) *openai.Client { | |||
c := openai.NewClientWithConfig(config) | |||
return c | |||
} | |||
func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) { | |||
client := getProxyClientFromToken(p.secretKey) | |||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) | |||
defer cancel() | |||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||
Input: []string{text}, | |||
Model: openai.EmbeddingModel(util.ParseInt(p.subType)), | |||
}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return resp.Data[0].Embedding, nil | |||
} |
@@ -12,7 +12,7 @@ | |||
// See the License for the specific language governing permissions and | |||
// limitations under the License. | |||
package model | |||
package embedding | |||
import ( | |||
"context" | |||
@@ -22,6 +22,23 @@ import ( | |||
"github.com/sashabaranov/go-openai" | |||
) | |||
type EmbeddingProvider interface { | |||
QueryVector(text string, timeout int) ([]float32, error) | |||
} | |||
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { | |||
var p EmbeddingProvider | |||
var err error | |||
if typ == "OpenAI" { | |||
p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
return p, nil | |||
} | |||
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { | |||
client := getProxyClientFromToken(authToken) | |||
@@ -21,6 +21,7 @@ import ( | |||
"net/http" | |||
"strings" | |||
"github.com/casbin/casibase/proxy" | |||
"github.com/sashabaranov/go-openai" | |||
) | |||
@@ -33,6 +34,14 @@ func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvi | |||
return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil | |||
} | |||
func getProxyClientFromToken(authToken string) *openai.Client { | |||
config := openai.DefaultConfig(authToken) | |||
config.HTTPClient = proxy.ProxyHttpClient | |||
c := openai.NewClientWithConfig(config) | |||
return c | |||
} | |||
func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | |||
client := getProxyClientFromToken(p.secretKey) | |||
@@ -17,6 +17,7 @@ package object | |||
import ( | |||
"fmt" | |||
"github.com/casbin/casibase/embedding" | |||
"github.com/casbin/casibase/model" | |||
"github.com/casbin/casibase/util" | |||
"xorm.io/core" | |||
@@ -116,6 +117,20 @@ func GetDefaultModelProvider() (*Provider, error) { | |||
return &provider, nil | |||
} | |||
func GetDefaultEmbeddingProvider() (*Provider, error) { | |||
provider := Provider{Owner: "admin", Category: "Embedding"} | |||
existed, err := adapter.engine.Get(&provider) | |||
if err != nil { | |||
return &provider, err | |||
} | |||
if !existed { | |||
return nil, nil | |||
} | |||
return &provider, nil | |||
} | |||
func UpdateProvider(id string, provider *Provider) (bool, error) { | |||
owner, name := util.GetOwnerAndNameFromId(id) | |||
p, err := getProvider(owner, name) | |||
@@ -173,3 +188,16 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||
return pProvider, nil | |||
} | |||
func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | |||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if pProvider == nil { | |||
return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type) | |||
} | |||
return pProvider, nil | |||
} |
@@ -44,8 +44,9 @@ type Store struct { | |||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | |||
DisplayName string `xorm:"varchar(100)" json:"displayName"` | |||
StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` | |||
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` | |||
StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` | |||
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` | |||
EmbeddingProvider string `xorm:"varchar(100)" json:"embeddingProvider"` | |||
FileTree *File `xorm:"mediumtext" json:"fileTree"` | |||
PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"` | |||
@@ -150,22 +151,26 @@ func (store *Store) GetId() string { | |||
return fmt.Sprintf("%s/%s", store.Owner, store.Name) | |||
} | |||
func (store *Store) GetModelProvider() (*Provider, error) { | |||
if store.ModelProvider == "" { | |||
return GetDefaultModelProvider() | |||
func (store *Store) GetEmbeddingProvider() (*Provider, error) { | |||
if store.EmbeddingProvider == "" { | |||
return GetDefaultEmbeddingProvider() | |||
} | |||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) | |||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.EmbeddingProvider) | |||
return GetProvider(providerId) | |||
} | |||
func RefreshStoreVectors(store *Store) (bool, error) { | |||
provider, err := store.GetModelProvider() | |||
embeddingProvider, err := store.GetEmbeddingProvider() | |||
if err != nil { | |||
return false, err | |||
} | |||
authToken := provider.ClientSecret | |||
ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name) | |||
embeddingProviderObj, err := embeddingProvider.GetEmbeddingProvider() | |||
if err != nil { | |||
return false, err | |||
} | |||
ok, err := addVectorsForStore(embeddingProviderObj, store.StorageProvider, "", store.Name) | |||
return ok, err | |||
} |
@@ -20,7 +20,7 @@ import ( | |||
"path/filepath" | |||
"time" | |||
"github.com/casbin/casibase/model" | |||
"github.com/casbin/casibase/embedding" | |||
"github.com/casbin/casibase/storage" | |||
"github.com/casbin/casibase/txt" | |||
"github.com/casbin/casibase/util" | |||
@@ -53,8 +53,9 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, | |||
return filterTextFiles(files), nil | |||
} | |||
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | |||
embedding, err := model.GetEmbeddingSafe(authToken, text) | |||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) { | |||
data, err := embeddingProviderObj.QueryVector(text, 5) | |||
// data, err := model.GetEmbeddingSafe(authToken, text) | |||
if err != nil { | |||
return false, err | |||
} | |||
@@ -72,16 +73,16 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName | |||
Store: storeName, | |||
File: fileName, | |||
Text: text, | |||
Data: embedding, | |||
Data: data, | |||
} | |||
return AddVector(vector) | |||
} | |||
func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) { | |||
func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storageProviderName string, key string, storeName string) (bool, error) { | |||
var affected bool | |||
var err error | |||
objs, err := getFilteredFileObjects(provider, key) | |||
objs, err := getFilteredFileObjects(storageProviderName, key) | |||
if err != nil { | |||
return false, err | |||
} | |||
@@ -99,7 +100,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName | |||
for i, textSection := range textSections { | |||
if timeLimiter.Allow() { | |||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) | |||
} else { | |||
err = timeLimiter.Wait(context.Background()) | |||
if err != nil { | |||
@@ -107,7 +108,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName | |||
} | |||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | |||
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) | |||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) | |||
} | |||
} | |||
} | |||
@@ -127,8 +128,9 @@ func getRelatedVectors(owner string) ([]*Vector, error) { | |||
return vectors, nil | |||
} | |||
func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | |||
qVector, err := model.GetEmbeddingSafe(authToken, question) | |||
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { | |||
qVector, err := embeddingProvider.QueryVector(text, 5) | |||
// qVector, err := embedding.GetEmbeddingSafe(authToken, question) | |||
if err != nil { | |||
return "", err | |||
} | |||
@@ -102,8 +102,7 @@ class ProviderEditPage extends React.Component { | |||
{ | |||
[ | |||
{id: "Model", name: "Model"}, | |||
{id: "Vector Database", name: "Vector Database"}, | |||
{id: "Storage", name: "Storage"}, | |||
{id: "Embedding", name: "Embedding"}, | |||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||
} | |||
</Select> | |||
@@ -116,11 +115,9 @@ class ProviderEditPage extends React.Component { | |||
<Col span={22} > | |||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> | |||
{ | |||
[ | |||
{id: "OpenAI", name: "OpenAI"}, | |||
{id: "Hugging Face", name: "Hugging Face"}, | |||
{id: "Ernie", name: "Ernie"}, | |||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||
Setting.getProviderTypeOptions(this.state.provider.category) | |||
// .sort((a, b) => a.name.localeCompare(b.name)) | |||
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||
} | |||
</Select> | |||
</Col> | |||
@@ -132,8 +129,8 @@ class ProviderEditPage extends React.Component { | |||
<Col span={22} > | |||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}> | |||
{ | |||
Setting.getProviderSubTypeOptions(this.state.provider.type) | |||
.sort((a, b) => a.name.localeCompare(b.name)) | |||
Setting.getProviderSubTypeOptions(this.state.provider.category, this.state.provider.type) | |||
// .sort((a, b) => a.name.localeCompare(b.name)) | |||
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||
} | |||
</Select> | |||
@@ -117,7 +117,7 @@ class ProviderListPage extends React.Component { | |||
title: i18next.t("general:Display name"), | |||
dataIndex: "displayName", | |||
key: "displayName", | |||
width: "170px", | |||
width: "220px", | |||
sorter: (a, b) => a.displayName.localeCompare(b.displayName), | |||
}, | |||
{ | |||
@@ -653,35 +653,81 @@ export function isResponseDenied(data) { | |||
return data.msg === "Unauthorized operation"; | |||
} | |||
export function getProviderSubTypeOptions(type) { | |||
if (type === "OpenAI") { | |||
export function getProviderTypeOptions(category) { | |||
if (category === "Model") { | |||
return ( | |||
[ | |||
{id: "OpenAI", name: "OpenAI"}, | |||
{id: "Hugging Face", name: "Hugging Face"}, | |||
{id: "Ernie", name: "Ernie"}, | |||
] | |||
); | |||
} else if (category === "Embedding") { | |||
return ( | |||
[ | |||
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||
{id: "gpt-4-32k", name: "gpt-4-32k"}, | |||
{id: "gpt-4-0613", name: "gpt-4-0613"}, | |||
{id: "gpt-4-0314", name: "gpt-4-0314"}, | |||
{id: "gpt-4", name: "gpt-4"}, | |||
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||
{id: "text-davinci-003", name: "text-davinci-003"}, | |||
{id: "text-davinci-002", name: "text-davinci-002"}, | |||
{id: "text-curie-001", name: "text-curie-001"}, | |||
{id: "text-babbage-001", name: "text-babbage-001"}, | |||
{id: "text-ada-001", name: "text-ada-001"}, | |||
{id: "text-davinci-001", name: "text-davinci-001"}, | |||
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||
{id: "davinci", name: "davinci"}, | |||
{id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||
{id: "curie", name: "curie"}, | |||
{id: "ada", name: "ada"}, | |||
{id: "babbage", name: "babbage"}, | |||
{id: "OpenAI", name: "OpenAI"}, | |||
] | |||
); | |||
} else { | |||
return []; | |||
} | |||
} | |||
export function getProviderSubTypeOptions(category, type) { | |||
if (type === "OpenAI") { | |||
if (category === "Model") { | |||
return ( | |||
[ | |||
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||
{id: "gpt-4-32k", name: "gpt-4-32k"}, | |||
{id: "gpt-4-0613", name: "gpt-4-0613"}, | |||
{id: "gpt-4-0314", name: "gpt-4-0314"}, | |||
{id: "gpt-4", name: "gpt-4"}, | |||
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||
{id: "text-davinci-003", name: "text-davinci-003"}, | |||
{id: "text-davinci-002", name: "text-davinci-002"}, | |||
{id: "text-curie-001", name: "text-curie-001"}, | |||
{id: "text-babbage-001", name: "text-babbage-001"}, | |||
{id: "text-ada-001", name: "text-ada-001"}, | |||
{id: "text-davinci-001", name: "text-davinci-001"}, | |||
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||
{id: "davinci", name: "davinci"}, | |||
{id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||
{id: "curie", name: "curie"}, | |||
{id: "ada", name: "ada"}, | |||
{id: "babbage", name: "babbage"}, | |||
] | |||
); | |||
} else if (category === "Embedding") { | |||
return ( | |||
[ | |||
{id: "1", name: "AdaSimilarity"}, | |||
{id: "2", name: "BabbageSimilarity"}, | |||
{id: "3", name: "CurieSimilarity"}, | |||
{id: "4", name: "DavinciSimilarity"}, | |||
{id: "5", name: "AdaSearchDocument"}, | |||
{id: "6", name: "AdaSearchQuery"}, | |||
{id: "7", name: "BabbageSearchDocument"}, | |||
{id: "8", name: "BabbageSearchQuery"}, | |||
{id: "9", name: "CurieSearchDocument"}, | |||
{id: "10", name: "CurieSearchQuery"}, | |||
{id: "11", name: "DavinciSearchDocument"}, | |||
{id: "12", name: "DavinciSearchQuery"}, | |||
{id: "13", name: "AdaCodeSearchCode"}, | |||
{id: "14", name: "AdaCodeSearchText"}, | |||
{id: "15", name: "BabbageCodeSearchCode"}, | |||
{id: "16", name: "BabbageCodeSearchText"}, | |||
{id: "17", name: "AdaEmbeddingV2"}, | |||
] | |||
); | |||
} else { | |||
return []; | |||
} | |||
} else if (type === "Hugging Face") { | |||
return ( | |||
[ | |||
@@ -30,6 +30,7 @@ class StoreEditPage extends React.Component { | |||
storeName: props.match.params.storeName, | |||
storageProviders: [], | |||
modelProviders: [], | |||
embeddingProviders: [], | |||
store: null, | |||
}; | |||
} | |||
@@ -37,7 +38,7 @@ class StoreEditPage extends React.Component { | |||
UNSAFE_componentWillMount() { | |||
this.getStore(); | |||
this.getStorageProviders(); | |||
this.getModelProviders(); | |||
this.getProviders(); | |||
} | |||
getStore() { | |||
@@ -70,12 +71,13 @@ class StoreEditPage extends React.Component { | |||
}); | |||
} | |||
getModelProviders() { | |||
getProviders() { | |||
ProviderBackend.getProviders(this.props.account.name) | |||
.then((res) => { | |||
if (res.status === "ok") { | |||
this.setState({ | |||
modelProviders: res.data.filter(provider => provider.category === "Model"), | |||
embeddingProviders: res.data.filter(provider => provider.category === "Embedding"), | |||
}); | |||
} else { | |||
Setting.showMessage("error", `Failed to get providers: ${res.msg}`); | |||
@@ -148,6 +150,16 @@ class StoreEditPage extends React.Component { | |||
} /> | |||
</Col> | |||
</Row> | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("store:Embedding provider")}: | |||
</Col> | |||
<Col span={22} > | |||
<Select virtual={false} style={{width: "100%"}} value={this.state.store.embeddingProvider} onChange={(value => {this.updateStoreField("embeddingProvider", value);})} | |||
options={this.state.embeddingProviders.map((provider) => Setting.getOption(`${provider.displayName} (${provider.name})`, `${provider.name}`)) | |||
} /> | |||
</Col> | |||
</Row> | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("store:File tree")}: | |||
@@ -56,6 +56,7 @@ class StoreListPage extends React.Component { | |||
displayName: `New Store - ${randomName}`, | |||
storageProvider: "", | |||
modelProvider: "", | |||
embeddingProvider: "", | |||
propertiesMap: {}, | |||
}; | |||
} | |||
@@ -168,6 +169,20 @@ class StoreListPage extends React.Component { | |||
); | |||
}, | |||
}, | |||
{ | |||
title: i18next.t("store:Embedding provider"), | |||
dataIndex: "embeddingProvider", | |||
key: "embeddingProvider", | |||
width: "250px", | |||
sorter: (a, b) => a.embeddingProvider.localeCompare(b.embeddingProvider), | |||
render: (text, record, index) => { | |||
return ( | |||
<Link to={`/providers/${text}`}> | |||
{text} | |||
</Link> | |||
); | |||
}, | |||
}, | |||
{ | |||
title: i18next.t("general:Action"), | |||
dataIndex: "action", | |||