@@ -68,46 +68,6 @@ func (c *ApiController) GetMessage() { | |||||
c.ResponseOk(message) | c.ResponseOk(message) | ||||
} | } | ||||
func (c *ApiController) ResponseErrorStream(errorText string) { | |||||
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) | |||||
_, err := c.Ctx.ResponseWriter.Write([]byte(event)) | |||||
if err != nil { | |||||
c.ResponseError(err.Error()) | |||||
return | |||||
} | |||||
} | |||||
func getModelProviderFromContext(owner string, name string) (*object.Provider, error) { | |||||
var providerName string | |||||
if name != "" { | |||||
providerName = name | |||||
} else { | |||||
store, err := object.GetDefaultStore(owner) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if store != nil && store.ModelProvider != "" { | |||||
providerName = store.ModelProvider | |||||
} | |||||
} | |||||
var provider *object.Provider | |||||
var err error | |||||
if providerName != "" { | |||||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||||
provider, err = object.GetProvider(providerId) | |||||
} else { | |||||
provider, err = object.GetDefaultModelProvider() | |||||
} | |||||
if provider == nil && err == nil { | |||||
return nil, fmt.Errorf("The provider: %s is not found", providerName) | |||||
} else { | |||||
return provider, err | |||||
} | |||||
} | |||||
func (c *ApiController) GetMessageAnswer() { | func (c *ApiController) GetMessageAnswer() { | ||||
id := c.Input().Get("id") | id := c.Input().Get("id") | ||||
@@ -154,13 +114,15 @@ func (c *ApiController) GetMessageAnswer() { | |||||
return | return | ||||
} | } | ||||
provider, err := getModelProviderFromContext(chat.Owner, chat.User2) | |||||
modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2) | |||||
if err != nil { | if err != nil { | ||||
c.ResponseError(err.Error()) | c.ResponseError(err.Error()) | ||||
return | return | ||||
} | } | ||||
if provider.Category != "Model" || provider.ClientSecret == "" { | |||||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId())) | |||||
embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2) | |||||
if err != nil { | |||||
c.ResponseError(err.Error()) | |||||
return | return | ||||
} | } | ||||
@@ -168,11 +130,10 @@ func (c *ApiController) GetMessageAnswer() { | |||||
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") | c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") | ||||
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") | c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") | ||||
authToken := provider.ClientSecret | |||||
question := questionMessage.Text | question := questionMessage.Text | ||||
var stringBuilder strings.Builder | var stringBuilder strings.Builder | ||||
nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question) | |||||
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) | |||||
if err != nil && err.Error() != "no knowledge vectors found" { | if err != nil && err.Error() != "no knowledge vectors found" { | ||||
c.ResponseErrorStream(err.Error()) | c.ResponseErrorStream(err.Error()) | ||||
return | return | ||||
@@ -184,13 +145,7 @@ func (c *ApiController) GetMessageAnswer() { | |||||
fmt.Printf("Context: [%s]\n", nearestText) | fmt.Printf("Context: [%s]\n", nearestText) | ||||
fmt.Printf("Answer: [") | fmt.Printf("Answer: [") | ||||
modelProvider, err := provider.GetModelProvider() | |||||
if err != nil { | |||||
c.ResponseErrorStream(err.Error()) | |||||
return | |||||
} | |||||
err = modelProvider.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) | |||||
err = modelProviderObj.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder) | |||||
if err != nil { | if err != nil { | ||||
c.ResponseErrorStream(err.Error()) | c.ResponseErrorStream(err.Error()) | ||||
return | return | ||||
@@ -0,0 +1,111 @@ | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
package controllers | |||||
import ( | |||||
"fmt" | |||||
"github.com/casbin/casibase/embedding" | |||||
"github.com/casbin/casibase/model" | |||||
"github.com/casbin/casibase/object" | |||||
"github.com/casbin/casibase/util" | |||||
) | |||||
func (c *ApiController) ResponseErrorStream(errorText string) { | |||||
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) | |||||
_, err := c.Ctx.ResponseWriter.Write([]byte(event)) | |||||
if err != nil { | |||||
c.ResponseError(err.Error()) | |||||
return | |||||
} | |||||
} | |||||
func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) { | |||||
var providerName string | |||||
if name != "" { | |||||
providerName = name | |||||
} else { | |||||
store, err := object.GetDefaultStore(owner) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if store != nil && store.ModelProvider != "" { | |||||
providerName = store.ModelProvider | |||||
} | |||||
} | |||||
var provider *object.Provider | |||||
var err error | |||||
if providerName != "" { | |||||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||||
provider, err = object.GetProvider(providerId) | |||||
} else { | |||||
provider, err = object.GetDefaultModelProvider() | |||||
} | |||||
if provider == nil && err == nil { | |||||
return nil, fmt.Errorf("The model provider: %s is not found", providerName) | |||||
} | |||||
if provider.Category != "Model" || provider.ClientSecret == "" { | |||||
return nil, fmt.Errorf("The model provider: %s is invalid", providerName) | |||||
} | |||||
providerObj, err := provider.GetModelProvider() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return providerObj, err | |||||
} | |||||
func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) { | |||||
var providerName string | |||||
if name != "" { | |||||
providerName = name | |||||
} else { | |||||
store, err := object.GetDefaultStore(owner) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if store != nil && store.EmbeddingProvider != "" { | |||||
providerName = store.EmbeddingProvider | |||||
} | |||||
} | |||||
var provider *object.Provider | |||||
var err error | |||||
if providerName != "" { | |||||
providerId := util.GetIdFromOwnerAndName(owner, providerName) | |||||
provider, err = object.GetProvider(providerId) | |||||
} else { | |||||
provider, err = object.GetDefaultEmbeddingProvider() | |||||
} | |||||
if provider == nil && err == nil { | |||||
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName) | |||||
} | |||||
if provider.Category != "Embedding" || provider.ClientSecret == "" { | |||||
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName) | |||||
} | |||||
providerObj, err := provider.GetEmbeddingProvider() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return providerObj, err | |||||
} |
@@ -12,13 +12,26 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package model | |||||
package embedding | |||||
import ( | import ( | ||||
"context" | |||||
"time" | |||||
"github.com/casbin/casibase/proxy" | "github.com/casbin/casibase/proxy" | ||||
"github.com/casbin/casibase/util" | |||||
"github.com/sashabaranov/go-openai" | "github.com/sashabaranov/go-openai" | ||||
) | ) | ||||
type OpenAiEmbeddingProvider struct { | |||||
subType string | |||||
secretKey string | |||||
} | |||||
func NewOpenAiEmbeddingProvider(subType string, secretKey string) (*OpenAiEmbeddingProvider, error) { | |||||
return &OpenAiEmbeddingProvider{subType: subType, secretKey: secretKey}, nil | |||||
} | |||||
func getProxyClientFromToken(authToken string) *openai.Client { | func getProxyClientFromToken(authToken string) *openai.Client { | ||||
config := openai.DefaultConfig(authToken) | config := openai.DefaultConfig(authToken) | ||||
config.HTTPClient = proxy.ProxyHttpClient | config.HTTPClient = proxy.ProxyHttpClient | ||||
@@ -26,3 +39,20 @@ func getProxyClientFromToken(authToken string) *openai.Client { | |||||
c := openai.NewClientWithConfig(config) | c := openai.NewClientWithConfig(config) | ||||
return c | return c | ||||
} | } | ||||
func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) { | |||||
client := getProxyClientFromToken(p.secretKey) | |||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) | |||||
defer cancel() | |||||
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ | |||||
Input: []string{text}, | |||||
Model: openai.EmbeddingModel(util.ParseInt(p.subType)), | |||||
}) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return resp.Data[0].Embedding, nil | |||||
} |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package model | |||||
package embedding | |||||
import ( | import ( | ||||
"context" | "context" | ||||
@@ -22,6 +22,23 @@ import ( | |||||
"github.com/sashabaranov/go-openai" | "github.com/sashabaranov/go-openai" | ||||
) | ) | ||||
type EmbeddingProvider interface { | |||||
QueryVector(text string, timeout int) ([]float32, error) | |||||
} | |||||
func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) { | |||||
var p EmbeddingProvider | |||||
var err error | |||||
if typ == "OpenAI" { | |||||
p, err = NewOpenAiEmbeddingProvider(subType, clientSecret) | |||||
} | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return p, nil | |||||
} | |||||
func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { | func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { | ||||
client := getProxyClientFromToken(authToken) | client := getProxyClientFromToken(authToken) | ||||
@@ -21,6 +21,7 @@ import ( | |||||
"net/http" | "net/http" | ||||
"strings" | "strings" | ||||
"github.com/casbin/casibase/proxy" | |||||
"github.com/sashabaranov/go-openai" | "github.com/sashabaranov/go-openai" | ||||
) | ) | ||||
@@ -33,6 +34,14 @@ func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvi | |||||
return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil | return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil | ||||
} | } | ||||
func getProxyClientFromToken(authToken string) *openai.Client { | |||||
config := openai.DefaultConfig(authToken) | |||||
config.HTTPClient = proxy.ProxyHttpClient | |||||
c := openai.NewClientWithConfig(config) | |||||
return c | |||||
} | |||||
func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | ||||
client := getProxyClientFromToken(p.secretKey) | client := getProxyClientFromToken(p.secretKey) | ||||
@@ -17,6 +17,7 @@ package object | |||||
import ( | import ( | ||||
"fmt" | "fmt" | ||||
"github.com/casbin/casibase/embedding" | |||||
"github.com/casbin/casibase/model" | "github.com/casbin/casibase/model" | ||||
"github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
"xorm.io/core" | "xorm.io/core" | ||||
@@ -116,6 +117,20 @@ func GetDefaultModelProvider() (*Provider, error) { | |||||
return &provider, nil | return &provider, nil | ||||
} | } | ||||
func GetDefaultEmbeddingProvider() (*Provider, error) { | |||||
provider := Provider{Owner: "admin", Category: "Embedding"} | |||||
existed, err := adapter.engine.Get(&provider) | |||||
if err != nil { | |||||
return &provider, err | |||||
} | |||||
if !existed { | |||||
return nil, nil | |||||
} | |||||
return &provider, nil | |||||
} | |||||
func UpdateProvider(id string, provider *Provider) (bool, error) { | func UpdateProvider(id string, provider *Provider) (bool, error) { | ||||
owner, name := util.GetOwnerAndNameFromId(id) | owner, name := util.GetOwnerAndNameFromId(id) | ||||
p, err := getProvider(owner, name) | p, err := getProvider(owner, name) | ||||
@@ -173,3 +188,16 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||||
return pProvider, nil | return pProvider, nil | ||||
} | } | ||||
func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) { | |||||
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if pProvider == nil { | |||||
return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type) | |||||
} | |||||
return pProvider, nil | |||||
} |
@@ -44,8 +44,9 @@ type Store struct { | |||||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | ||||
DisplayName string `xorm:"varchar(100)" json:"displayName"` | DisplayName string `xorm:"varchar(100)" json:"displayName"` | ||||
StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` | |||||
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` | |||||
StorageProvider string `xorm:"varchar(100)" json:"storageProvider"` | |||||
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"` | |||||
EmbeddingProvider string `xorm:"varchar(100)" json:"embeddingProvider"` | |||||
FileTree *File `xorm:"mediumtext" json:"fileTree"` | FileTree *File `xorm:"mediumtext" json:"fileTree"` | ||||
PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"` | PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"` | ||||
@@ -150,22 +151,26 @@ func (store *Store) GetId() string { | |||||
return fmt.Sprintf("%s/%s", store.Owner, store.Name) | return fmt.Sprintf("%s/%s", store.Owner, store.Name) | ||||
} | } | ||||
func (store *Store) GetModelProvider() (*Provider, error) { | |||||
if store.ModelProvider == "" { | |||||
return GetDefaultModelProvider() | |||||
func (store *Store) GetEmbeddingProvider() (*Provider, error) { | |||||
if store.EmbeddingProvider == "" { | |||||
return GetDefaultEmbeddingProvider() | |||||
} | } | ||||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider) | |||||
providerId := util.GetIdFromOwnerAndName(store.Owner, store.EmbeddingProvider) | |||||
return GetProvider(providerId) | return GetProvider(providerId) | ||||
} | } | ||||
func RefreshStoreVectors(store *Store) (bool, error) { | func RefreshStoreVectors(store *Store) (bool, error) { | ||||
provider, err := store.GetModelProvider() | |||||
embeddingProvider, err := store.GetEmbeddingProvider() | |||||
if err != nil { | if err != nil { | ||||
return false, err | return false, err | ||||
} | } | ||||
authToken := provider.ClientSecret | |||||
ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name) | |||||
embeddingProviderObj, err := embeddingProvider.GetEmbeddingProvider() | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
ok, err := addVectorsForStore(embeddingProviderObj, store.StorageProvider, "", store.Name) | |||||
return ok, err | return ok, err | ||||
} | } |
@@ -20,7 +20,7 @@ import ( | |||||
"path/filepath" | "path/filepath" | ||||
"time" | "time" | ||||
"github.com/casbin/casibase/model" | |||||
"github.com/casbin/casibase/embedding" | |||||
"github.com/casbin/casibase/storage" | "github.com/casbin/casibase/storage" | ||||
"github.com/casbin/casibase/txt" | "github.com/casbin/casibase/txt" | ||||
"github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
@@ -53,8 +53,9 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, | |||||
return filterTextFiles(files), nil | return filterTextFiles(files), nil | ||||
} | } | ||||
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | |||||
embedding, err := model.GetEmbeddingSafe(authToken, text) | |||||
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) { | |||||
data, err := embeddingProviderObj.QueryVector(text, 5) | |||||
// data, err := model.GetEmbeddingSafe(authToken, text) | |||||
if err != nil { | if err != nil { | ||||
return false, err | return false, err | ||||
} | } | ||||
@@ -72,16 +73,16 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName | |||||
Store: storeName, | Store: storeName, | ||||
File: fileName, | File: fileName, | ||||
Text: text, | Text: text, | ||||
Data: embedding, | |||||
Data: data, | |||||
} | } | ||||
return AddVector(vector) | return AddVector(vector) | ||||
} | } | ||||
func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) { | |||||
func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storageProviderName string, key string, storeName string) (bool, error) { | |||||
var affected bool | var affected bool | ||||
var err error | var err error | ||||
objs, err := getFilteredFileObjects(provider, key) | |||||
objs, err := getFilteredFileObjects(storageProviderName, key) | |||||
if err != nil { | if err != nil { | ||||
return false, err | return false, err | ||||
} | } | ||||
@@ -99,7 +100,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName | |||||
for i, textSection := range textSections { | for i, textSection := range textSections { | ||||
if timeLimiter.Allow() { | if timeLimiter.Allow() { | ||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | ||||
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) | |||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) | |||||
} else { | } else { | ||||
err = timeLimiter.Wait(context.Background()) | err = timeLimiter.Wait(context.Background()) | ||||
if err != nil { | if err != nil { | ||||
@@ -107,7 +108,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName | |||||
} | } | ||||
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) | ||||
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key) | |||||
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -127,8 +128,9 @@ func getRelatedVectors(owner string) ([]*Vector, error) { | |||||
return vectors, nil | return vectors, nil | ||||
} | } | ||||
func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | |||||
qVector, err := model.GetEmbeddingSafe(authToken, question) | |||||
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { | |||||
qVector, err := embeddingProvider.QueryVector(text, 5) | |||||
// qVector, err := embedding.GetEmbeddingSafe(authToken, question) | |||||
if err != nil { | if err != nil { | ||||
return "", err | return "", err | ||||
} | } | ||||
@@ -102,8 +102,7 @@ class ProviderEditPage extends React.Component { | |||||
{ | { | ||||
[ | [ | ||||
{id: "Model", name: "Model"}, | {id: "Model", name: "Model"}, | ||||
{id: "Vector Database", name: "Vector Database"}, | |||||
{id: "Storage", name: "Storage"}, | |||||
{id: "Embedding", name: "Embedding"}, | |||||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | ].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | ||||
} | } | ||||
</Select> | </Select> | ||||
@@ -116,11 +115,9 @@ class ProviderEditPage extends React.Component { | |||||
<Col span={22} > | <Col span={22} > | ||||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> | <Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> | ||||
{ | { | ||||
[ | |||||
{id: "OpenAI", name: "OpenAI"}, | |||||
{id: "Hugging Face", name: "Hugging Face"}, | |||||
{id: "Ernie", name: "Ernie"}, | |||||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||||
Setting.getProviderTypeOptions(this.state.provider.category) | |||||
// .sort((a, b) => a.name.localeCompare(b.name)) | |||||
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||||
} | } | ||||
</Select> | </Select> | ||||
</Col> | </Col> | ||||
@@ -132,8 +129,8 @@ class ProviderEditPage extends React.Component { | |||||
<Col span={22} > | <Col span={22} > | ||||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}> | <Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}> | ||||
{ | { | ||||
Setting.getProviderSubTypeOptions(this.state.provider.type) | |||||
.sort((a, b) => a.name.localeCompare(b.name)) | |||||
Setting.getProviderSubTypeOptions(this.state.provider.category, this.state.provider.type) | |||||
// .sort((a, b) => a.name.localeCompare(b.name)) | |||||
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | .map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | ||||
} | } | ||||
</Select> | </Select> | ||||
@@ -117,7 +117,7 @@ class ProviderListPage extends React.Component { | |||||
title: i18next.t("general:Display name"), | title: i18next.t("general:Display name"), | ||||
dataIndex: "displayName", | dataIndex: "displayName", | ||||
key: "displayName", | key: "displayName", | ||||
width: "170px", | |||||
width: "220px", | |||||
sorter: (a, b) => a.displayName.localeCompare(b.displayName), | sorter: (a, b) => a.displayName.localeCompare(b.displayName), | ||||
}, | }, | ||||
{ | { | ||||
@@ -653,35 +653,81 @@ export function isResponseDenied(data) { | |||||
return data.msg === "Unauthorized operation"; | return data.msg === "Unauthorized operation"; | ||||
} | } | ||||
export function getProviderSubTypeOptions(type) { | |||||
if (type === "OpenAI") { | |||||
export function getProviderTypeOptions(category) { | |||||
if (category === "Model") { | |||||
return ( | |||||
[ | |||||
{id: "OpenAI", name: "OpenAI"}, | |||||
{id: "Hugging Face", name: "Hugging Face"}, | |||||
{id: "Ernie", name: "Ernie"}, | |||||
] | |||||
); | |||||
} else if (category === "Embedding") { | |||||
return ( | return ( | ||||
[ | [ | ||||
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||||
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||||
{id: "gpt-4-32k", name: "gpt-4-32k"}, | |||||
{id: "gpt-4-0613", name: "gpt-4-0613"}, | |||||
{id: "gpt-4-0314", name: "gpt-4-0314"}, | |||||
{id: "gpt-4", name: "gpt-4"}, | |||||
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||||
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||||
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||||
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||||
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||||
{id: "text-davinci-003", name: "text-davinci-003"}, | |||||
{id: "text-davinci-002", name: "text-davinci-002"}, | |||||
{id: "text-curie-001", name: "text-curie-001"}, | |||||
{id: "text-babbage-001", name: "text-babbage-001"}, | |||||
{id: "text-ada-001", name: "text-ada-001"}, | |||||
{id: "text-davinci-001", name: "text-davinci-001"}, | |||||
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||||
{id: "davinci", name: "davinci"}, | |||||
{id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||||
{id: "curie", name: "curie"}, | |||||
{id: "ada", name: "ada"}, | |||||
{id: "babbage", name: "babbage"}, | |||||
{id: "OpenAI", name: "OpenAI"}, | |||||
] | ] | ||||
); | ); | ||||
} else { | |||||
return []; | |||||
} | |||||
} | |||||
export function getProviderSubTypeOptions(category, type) { | |||||
if (type === "OpenAI") { | |||||
if (category === "Model") { | |||||
return ( | |||||
[ | |||||
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||||
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||||
{id: "gpt-4-32k", name: "gpt-4-32k"}, | |||||
{id: "gpt-4-0613", name: "gpt-4-0613"}, | |||||
{id: "gpt-4-0314", name: "gpt-4-0314"}, | |||||
{id: "gpt-4", name: "gpt-4"}, | |||||
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||||
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||||
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||||
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||||
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||||
{id: "text-davinci-003", name: "text-davinci-003"}, | |||||
{id: "text-davinci-002", name: "text-davinci-002"}, | |||||
{id: "text-curie-001", name: "text-curie-001"}, | |||||
{id: "text-babbage-001", name: "text-babbage-001"}, | |||||
{id: "text-ada-001", name: "text-ada-001"}, | |||||
{id: "text-davinci-001", name: "text-davinci-001"}, | |||||
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||||
{id: "davinci", name: "davinci"}, | |||||
{id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||||
{id: "curie", name: "curie"}, | |||||
{id: "ada", name: "ada"}, | |||||
{id: "babbage", name: "babbage"}, | |||||
] | |||||
); | |||||
} else if (category === "Embedding") { | |||||
return ( | |||||
[ | |||||
{id: "1", name: "AdaSimilarity"}, | |||||
{id: "2", name: "BabbageSimilarity"}, | |||||
{id: "3", name: "CurieSimilarity"}, | |||||
{id: "4", name: "DavinciSimilarity"}, | |||||
{id: "5", name: "AdaSearchDocument"}, | |||||
{id: "6", name: "AdaSearchQuery"}, | |||||
{id: "7", name: "BabbageSearchDocument"}, | |||||
{id: "8", name: "BabbageSearchQuery"}, | |||||
{id: "9", name: "CurieSearchDocument"}, | |||||
{id: "10", name: "CurieSearchQuery"}, | |||||
{id: "11", name: "DavinciSearchDocument"}, | |||||
{id: "12", name: "DavinciSearchQuery"}, | |||||
{id: "13", name: "AdaCodeSearchCode"}, | |||||
{id: "14", name: "AdaCodeSearchText"}, | |||||
{id: "15", name: "BabbageCodeSearchCode"}, | |||||
{id: "16", name: "BabbageCodeSearchText"}, | |||||
{id: "17", name: "AdaEmbeddingV2"}, | |||||
] | |||||
); | |||||
} else { | |||||
return []; | |||||
} | |||||
} else if (type === "Hugging Face") { | } else if (type === "Hugging Face") { | ||||
return ( | return ( | ||||
[ | [ | ||||
@@ -30,6 +30,7 @@ class StoreEditPage extends React.Component { | |||||
storeName: props.match.params.storeName, | storeName: props.match.params.storeName, | ||||
storageProviders: [], | storageProviders: [], | ||||
modelProviders: [], | modelProviders: [], | ||||
embeddingProviders: [], | |||||
store: null, | store: null, | ||||
}; | }; | ||||
} | } | ||||
@@ -37,7 +38,7 @@ class StoreEditPage extends React.Component { | |||||
UNSAFE_componentWillMount() { | UNSAFE_componentWillMount() { | ||||
this.getStore(); | this.getStore(); | ||||
this.getStorageProviders(); | this.getStorageProviders(); | ||||
this.getModelProviders(); | |||||
this.getProviders(); | |||||
} | } | ||||
getStore() { | getStore() { | ||||
@@ -70,12 +71,13 @@ class StoreEditPage extends React.Component { | |||||
}); | }); | ||||
} | } | ||||
getModelProviders() { | |||||
getProviders() { | |||||
ProviderBackend.getProviders(this.props.account.name) | ProviderBackend.getProviders(this.props.account.name) | ||||
.then((res) => { | .then((res) => { | ||||
if (res.status === "ok") { | if (res.status === "ok") { | ||||
this.setState({ | this.setState({ | ||||
modelProviders: res.data.filter(provider => provider.category === "Model"), | modelProviders: res.data.filter(provider => provider.category === "Model"), | ||||
embeddingProviders: res.data.filter(provider => provider.category === "Embedding"), | |||||
}); | }); | ||||
} else { | } else { | ||||
Setting.showMessage("error", `Failed to get providers: ${res.msg}`); | Setting.showMessage("error", `Failed to get providers: ${res.msg}`); | ||||
@@ -148,6 +150,16 @@ class StoreEditPage extends React.Component { | |||||
} /> | } /> | ||||
</Col> | </Col> | ||||
</Row> | </Row> | ||||
<Row style={{marginTop: "20px"}} > | |||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||||
{i18next.t("store:Embedding provider")}: | |||||
</Col> | |||||
<Col span={22} > | |||||
<Select virtual={false} style={{width: "100%"}} value={this.state.store.embeddingProvider} onChange={(value => {this.updateStoreField("embeddingProvider", value);})} | |||||
options={this.state.embeddingProviders.map((provider) => Setting.getOption(`${provider.displayName} (${provider.name})`, `${provider.name}`)) | |||||
} /> | |||||
</Col> | |||||
</Row> | |||||
<Row style={{marginTop: "20px"}} > | <Row style={{marginTop: "20px"}} > | ||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | ||||
{i18next.t("store:File tree")}: | {i18next.t("store:File tree")}: | ||||
@@ -56,6 +56,7 @@ class StoreListPage extends React.Component { | |||||
displayName: `New Store - ${randomName}`, | displayName: `New Store - ${randomName}`, | ||||
storageProvider: "", | storageProvider: "", | ||||
modelProvider: "", | modelProvider: "", | ||||
embeddingProvider: "", | |||||
propertiesMap: {}, | propertiesMap: {}, | ||||
}; | }; | ||||
} | } | ||||
@@ -168,6 +169,20 @@ class StoreListPage extends React.Component { | |||||
); | ); | ||||
}, | }, | ||||
}, | }, | ||||
{ | |||||
title: i18next.t("store:Embedding provider"), | |||||
dataIndex: "embeddingProvider", | |||||
key: "embeddingProvider", | |||||
width: "250px", | |||||
sorter: (a, b) => a.embeddingProvider.localeCompare(b.embeddingProvider), | |||||
render: (text, record, index) => { | |||||
return ( | |||||
<Link to={`/providers/${text}`}> | |||||
{text} | |||||
</Link> | |||||
); | |||||
}, | |||||
}, | |||||
{ | { | ||||
title: i18next.t("general:Action"), | title: i18next.t("general:Action"), | ||||
dataIndex: "action", | dataIndex: "action", | ||||