Browse Source

Add embedding provider

HEAD
Yang Luo 2 years ago
parent
commit
2c8f6ac553
14 changed files with 337 additions and 110 deletions
  1. +7
    -52
      controllers/message.go
  2. +111
    -0
      controllers/message_util.go
  3. +31
    -1
      embedding/openai.go
  4. +18
    -1
      embedding/provider.go
  5. +9
    -0
      model/openai.go
  6. +0
    -0
      model/provider.go
  7. +28
    -0
      object/provider.go
  8. +14
    -9
      object/store.go
  9. +12
    -10
      object/vector_embedding.go
  10. +6
    -9
      web/src/ProviderEditPage.js
  11. +1
    -1
      web/src/ProviderListPage.js
  12. +71
    -25
      web/src/Setting.js
  13. +14
    -2
      web/src/StoreEditPage.js
  14. +15
    -0
      web/src/StoreListPage.js

+ 7
- 52
controllers/message.go View File

@@ -68,46 +68,6 @@ func (c *ApiController) GetMessage() {
c.ResponseOk(message) c.ResponseOk(message)
} }


func (c *ApiController) ResponseErrorStream(errorText string) {
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText)
_, err := c.Ctx.ResponseWriter.Write([]byte(event))
if err != nil {
c.ResponseError(err.Error())
return
}
}

func getModelProviderFromContext(owner string, name string) (*object.Provider, error) {
var providerName string
if name != "" {
providerName = name
} else {
store, err := object.GetDefaultStore(owner)
if err != nil {
return nil, err
}

if store != nil && store.ModelProvider != "" {
providerName = store.ModelProvider
}
}

var provider *object.Provider
var err error
if providerName != "" {
providerId := util.GetIdFromOwnerAndName(owner, providerName)
provider, err = object.GetProvider(providerId)
} else {
provider, err = object.GetDefaultModelProvider()
}

if provider == nil && err == nil {
return nil, fmt.Errorf("The provider: %s is not found", providerName)
} else {
return provider, err
}
}

func (c *ApiController) GetMessageAnswer() { func (c *ApiController) GetMessageAnswer() {
id := c.Input().Get("id") id := c.Input().Get("id")


@@ -154,13 +114,15 @@ func (c *ApiController) GetMessageAnswer() {
return return
} }


provider, err := getModelProviderFromContext(chat.Owner, chat.User2)
modelProviderObj, err := getModelProviderFromContext(chat.Owner, chat.User2)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }
if provider.Category != "Model" || provider.ClientSecret == "" {
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", provider.GetId()))

embeddingProviderObj, err := getEmbeddingProviderFromContext(chat.Owner, chat.User2)
if err != nil {
c.ResponseError(err.Error())
return return
} }


@@ -168,11 +130,10 @@ func (c *ApiController) GetMessageAnswer() {
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")


authToken := provider.ClientSecret
question := questionMessage.Text question := questionMessage.Text
var stringBuilder strings.Builder var stringBuilder strings.Builder


nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question)
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question)
if err != nil && err.Error() != "no knowledge vectors found" { if err != nil && err.Error() != "no knowledge vectors found" {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return
@@ -184,13 +145,7 @@ func (c *ApiController) GetMessageAnswer() {
fmt.Printf("Context: [%s]\n", nearestText) fmt.Printf("Context: [%s]\n", nearestText)
fmt.Printf("Answer: [") fmt.Printf("Answer: [")


modelProvider, err := provider.GetModelProvider()
if err != nil {
c.ResponseErrorStream(err.Error())
return
}

err = modelProvider.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder)
err = modelProviderObj.QueryText(realQuestion, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil { if err != nil {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return


+ 111
- 0
controllers/message_util.go View File

@@ -0,0 +1,111 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package controllers

import (
"fmt"

"github.com/casbin/casibase/embedding"
"github.com/casbin/casibase/model"
"github.com/casbin/casibase/object"
"github.com/casbin/casibase/util"
)

func (c *ApiController) ResponseErrorStream(errorText string) {
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText)
_, err := c.Ctx.ResponseWriter.Write([]byte(event))
if err != nil {
c.ResponseError(err.Error())
return
}
}

func getModelProviderFromContext(owner string, name string) (model.ModelProvider, error) {
var providerName string
if name != "" {
providerName = name
} else {
store, err := object.GetDefaultStore(owner)
if err != nil {
return nil, err
}

if store != nil && store.ModelProvider != "" {
providerName = store.ModelProvider
}
}

var provider *object.Provider
var err error
if providerName != "" {
providerId := util.GetIdFromOwnerAndName(owner, providerName)
provider, err = object.GetProvider(providerId)
} else {
provider, err = object.GetDefaultModelProvider()
}

if provider == nil && err == nil {
return nil, fmt.Errorf("The model provider: %s is not found", providerName)
}
if provider.Category != "Model" || provider.ClientSecret == "" {
return nil, fmt.Errorf("The model provider: %s is invalid", providerName)
}

providerObj, err := provider.GetModelProvider()
if err != nil {
return nil, err
}

return providerObj, err
}

func getEmbeddingProviderFromContext(owner string, name string) (embedding.EmbeddingProvider, error) {
var providerName string
if name != "" {
providerName = name
} else {
store, err := object.GetDefaultStore(owner)
if err != nil {
return nil, err
}

if store != nil && store.EmbeddingProvider != "" {
providerName = store.EmbeddingProvider
}
}

var provider *object.Provider
var err error
if providerName != "" {
providerId := util.GetIdFromOwnerAndName(owner, providerName)
provider, err = object.GetProvider(providerId)
} else {
provider, err = object.GetDefaultEmbeddingProvider()
}

if provider == nil && err == nil {
return nil, fmt.Errorf("The embedding provider: %s is not found", providerName)
}
if provider.Category != "Embedding" || provider.ClientSecret == "" {
return nil, fmt.Errorf("The embedding provider: %s is invalid", providerName)
}

providerObj, err := provider.GetEmbeddingProvider()
if err != nil {
return nil, err
}

return providerObj, err
}

model/openai_proxy.go → embedding/openai.go View File

@@ -12,13 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package model
package embedding


import ( import (
"context"
"time"

"github.com/casbin/casibase/proxy" "github.com/casbin/casibase/proxy"
"github.com/casbin/casibase/util"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )


type OpenAiEmbeddingProvider struct {
subType string
secretKey string
}

func NewOpenAiEmbeddingProvider(subType string, secretKey string) (*OpenAiEmbeddingProvider, error) {
return &OpenAiEmbeddingProvider{subType: subType, secretKey: secretKey}, nil
}

func getProxyClientFromToken(authToken string) *openai.Client { func getProxyClientFromToken(authToken string) *openai.Client {
config := openai.DefaultConfig(authToken) config := openai.DefaultConfig(authToken)
config.HTTPClient = proxy.ProxyHttpClient config.HTTPClient = proxy.ProxyHttpClient
@@ -26,3 +39,20 @@ func getProxyClientFromToken(authToken string) *openai.Client {
c := openai.NewClientWithConfig(config) c := openai.NewClientWithConfig(config)
return c return c
} }

func (p *OpenAiEmbeddingProvider) QueryVector(text string, timeout int) ([]float32, error) {
client := getProxyClientFromToken(p.secretKey)

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second)
defer cancel()

resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Input: []string{text},
Model: openai.EmbeddingModel(util.ParseInt(p.subType)),
})
if err != nil {
return nil, err
}

return resp.Data[0].Embedding, nil
}

model/embedding.go → embedding/provider.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package model
package embedding


import ( import (
"context" "context"
@@ -22,6 +22,23 @@ import (
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )


type EmbeddingProvider interface {
QueryVector(text string, timeout int) ([]float32, error)
}

func GetEmbeddingProvider(typ string, subType string, clientSecret string) (EmbeddingProvider, error) {
var p EmbeddingProvider
var err error
if typ == "OpenAI" {
p, err = NewOpenAiEmbeddingProvider(subType, clientSecret)
}

if err != nil {
return nil, err
}
return p, nil
}

func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { func getEmbedding(authToken string, text string, timeout int) ([]float32, error) {
client := getProxyClientFromToken(authToken) client := getProxyClientFromToken(authToken)



+ 9
- 0
model/openai.go View File

@@ -21,6 +21,7 @@ import (
"net/http" "net/http"
"strings" "strings"


"github.com/casbin/casibase/proxy"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )


@@ -33,6 +34,14 @@ func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvi
return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil
} }


func getProxyClientFromToken(authToken string) *openai.Client {
config := openai.DefaultConfig(authToken)
config.HTTPClient = proxy.ProxyHttpClient

c := openai.NewClientWithConfig(config)
return c
}

func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
client := getProxyClientFromToken(p.secretKey) client := getProxyClientFromToken(p.secretKey)




model/model.go → model/provider.go View File


+ 28
- 0
object/provider.go View File

@@ -17,6 +17,7 @@ package object
import ( import (
"fmt" "fmt"


"github.com/casbin/casibase/embedding"
"github.com/casbin/casibase/model" "github.com/casbin/casibase/model"
"github.com/casbin/casibase/util" "github.com/casbin/casibase/util"
"xorm.io/core" "xorm.io/core"
@@ -116,6 +117,20 @@ func GetDefaultModelProvider() (*Provider, error) {
return &provider, nil return &provider, nil
} }


func GetDefaultEmbeddingProvider() (*Provider, error) {
provider := Provider{Owner: "admin", Category: "Embedding"}
existed, err := adapter.engine.Get(&provider)
if err != nil {
return &provider, err
}

if !existed {
return nil, nil
}

return &provider, nil
}

func UpdateProvider(id string, provider *Provider) (bool, error) { func UpdateProvider(id string, provider *Provider) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
p, err := getProvider(owner, name) p, err := getProvider(owner, name)
@@ -173,3 +188,16 @@ func (p *Provider) GetModelProvider() (model.ModelProvider, error) {


return pProvider, nil return pProvider, nil
} }

func (p *Provider) GetEmbeddingProvider() (embedding.EmbeddingProvider, error) {
pProvider, err := embedding.GetEmbeddingProvider(p.Type, p.SubType, p.ClientSecret)
if err != nil {
return nil, err
}

if pProvider == nil {
return nil, fmt.Errorf("the embedding provider type: %s is not supported", p.Type)
}

return pProvider, nil
}

+ 14
- 9
object/store.go View File

@@ -44,8 +44,9 @@ type Store struct {
CreatedTime string `xorm:"varchar(100)" json:"createdTime"` CreatedTime string `xorm:"varchar(100)" json:"createdTime"`
DisplayName string `xorm:"varchar(100)" json:"displayName"` DisplayName string `xorm:"varchar(100)" json:"displayName"`


StorageProvider string `xorm:"varchar(100)" json:"storageProvider"`
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"`
StorageProvider string `xorm:"varchar(100)" json:"storageProvider"`
ModelProvider string `xorm:"varchar(100)" json:"modelProvider"`
EmbeddingProvider string `xorm:"varchar(100)" json:"embeddingProvider"`


FileTree *File `xorm:"mediumtext" json:"fileTree"` FileTree *File `xorm:"mediumtext" json:"fileTree"`
PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"` PropertiesMap map[string]*Properties `xorm:"mediumtext" json:"propertiesMap"`
@@ -150,22 +151,26 @@ func (store *Store) GetId() string {
return fmt.Sprintf("%s/%s", store.Owner, store.Name) return fmt.Sprintf("%s/%s", store.Owner, store.Name)
} }


func (store *Store) GetModelProvider() (*Provider, error) {
if store.ModelProvider == "" {
return GetDefaultModelProvider()
func (store *Store) GetEmbeddingProvider() (*Provider, error) {
if store.EmbeddingProvider == "" {
return GetDefaultEmbeddingProvider()
} }


providerId := util.GetIdFromOwnerAndName(store.Owner, store.ModelProvider)
providerId := util.GetIdFromOwnerAndName(store.Owner, store.EmbeddingProvider)
return GetProvider(providerId) return GetProvider(providerId)
} }


func RefreshStoreVectors(store *Store) (bool, error) { func RefreshStoreVectors(store *Store) (bool, error) {
provider, err := store.GetModelProvider()
embeddingProvider, err := store.GetEmbeddingProvider()
if err != nil { if err != nil {
return false, err return false, err
} }


authToken := provider.ClientSecret
ok, err := addVectorsForStore(authToken, store.StorageProvider, "", store.Name)
embeddingProviderObj, err := embeddingProvider.GetEmbeddingProvider()
if err != nil {
return false, err
}

ok, err := addVectorsForStore(embeddingProviderObj, store.StorageProvider, "", store.Name)
return ok, err return ok, err
} }

+ 12
- 10
object/vector_embedding.go View File

@@ -20,7 +20,7 @@ import (
"path/filepath" "path/filepath"
"time" "time"


"github.com/casbin/casibase/model"
"github.com/casbin/casibase/embedding"
"github.com/casbin/casibase/storage" "github.com/casbin/casibase/storage"
"github.com/casbin/casibase/txt" "github.com/casbin/casibase/txt"
"github.com/casbin/casibase/util" "github.com/casbin/casibase/util"
@@ -53,8 +53,9 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
return filterTextFiles(files), nil return filterTextFiles(files), nil
} }


func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
embedding, err := model.GetEmbeddingSafe(authToken, text)
func addEmbeddedVector(embeddingProviderObj embedding.EmbeddingProvider, text string, storeName string, fileName string) (bool, error) {
data, err := embeddingProviderObj.QueryVector(text, 5)
// data, err := model.GetEmbeddingSafe(authToken, text)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -72,16 +73,16 @@ func addEmbeddedVector(authToken string, text string, storeName string, fileName
Store: storeName, Store: storeName,
File: fileName, File: fileName,
Text: text, Text: text,
Data: embedding,
Data: data,
} }
return AddVector(vector) return AddVector(vector)
} }


func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) {
func addVectorsForStore(embeddingProviderObj embedding.EmbeddingProvider, storageProviderName string, key string, storeName string) (bool, error) {
var affected bool var affected bool
var err error var err error


objs, err := getFilteredFileObjects(provider, key)
objs, err := getFilteredFileObjects(storageProviderName, key)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -99,7 +100,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName
for i, textSection := range textSections { for i, textSection := range textSections {
if timeLimiter.Allow() { if timeLimiter.Allow() {
fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key)
} else { } else {
err = timeLimiter.Wait(context.Background()) err = timeLimiter.Wait(context.Background())
if err != nil { if err != nil {
@@ -107,7 +108,7 @@ func addVectorsForStore(authToken string, provider string, key string, storeName
} }


fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection) fmt.Printf("[%d/%d] Generating embedding for store: [%s]'s text section: %s\n", i+1, len(textSections), storeName, textSection)
affected, err = addEmbeddedVector(authToken, textSection, storeName, obj.Key)
affected, err = addEmbeddedVector(embeddingProviderObj, textSection, storeName, obj.Key)
} }
} }
} }
@@ -127,8 +128,9 @@ func getRelatedVectors(owner string) ([]*Vector, error) {
return vectors, nil return vectors, nil
} }


func GetNearestVectorText(authToken string, owner string, question string) (string, error) {
qVector, err := model.GetEmbeddingSafe(authToken, question)
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
qVector, err := embeddingProvider.QueryVector(text, 5)
// qVector, err := embedding.GetEmbeddingSafe(authToken, question)
if err != nil { if err != nil {
return "", err return "", err
} }


+ 6
- 9
web/src/ProviderEditPage.js View File

@@ -102,8 +102,7 @@ class ProviderEditPage extends React.Component {
{ {
[ [
{id: "Model", name: "Model"}, {id: "Model", name: "Model"},
{id: "Vector Database", name: "Vector Database"},
{id: "Storage", name: "Storage"},
{id: "Embedding", name: "Embedding"},
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) ].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
} }
</Select> </Select>
@@ -116,11 +115,9 @@ class ProviderEditPage extends React.Component {
<Col span={22} > <Col span={22} >
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> <Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}>
{ {
[
{id: "OpenAI", name: "OpenAI"},
{id: "Hugging Face", name: "Hugging Face"},
{id: "Ernie", name: "Ernie"},
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
Setting.getProviderTypeOptions(this.state.provider.category)
// .sort((a, b) => a.name.localeCompare(b.name))
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
} }
</Select> </Select>
</Col> </Col>
@@ -132,8 +129,8 @@ class ProviderEditPage extends React.Component {
<Col span={22} > <Col span={22} >
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}> <Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}>
{ {
Setting.getProviderSubTypeOptions(this.state.provider.type)
.sort((a, b) => a.name.localeCompare(b.name))
Setting.getProviderSubTypeOptions(this.state.provider.category, this.state.provider.type)
// .sort((a, b) => a.name.localeCompare(b.name))
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) .map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
} }
</Select> </Select>


+ 1
- 1
web/src/ProviderListPage.js View File

@@ -117,7 +117,7 @@ class ProviderListPage extends React.Component {
title: i18next.t("general:Display name"), title: i18next.t("general:Display name"),
dataIndex: "displayName", dataIndex: "displayName",
key: "displayName", key: "displayName",
width: "170px",
width: "220px",
sorter: (a, b) => a.displayName.localeCompare(b.displayName), sorter: (a, b) => a.displayName.localeCompare(b.displayName),
}, },
{ {


+ 71
- 25
web/src/Setting.js View File

@@ -653,35 +653,81 @@ export function isResponseDenied(data) {
return data.msg === "Unauthorized operation"; return data.msg === "Unauthorized operation";
} }
export function getProviderSubTypeOptions(type) {
if (type === "OpenAI") {
export function getProviderTypeOptions(category) {
if (category === "Model") {
return (
[
{id: "OpenAI", name: "OpenAI"},
{id: "Hugging Face", name: "Hugging Face"},
{id: "Ernie", name: "Ernie"},
]
);
} else if (category === "Embedding") {
return ( return (
[ [
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
{id: "gpt-4-32k", name: "gpt-4-32k"},
{id: "gpt-4-0613", name: "gpt-4-0613"},
{id: "gpt-4-0314", name: "gpt-4-0314"},
{id: "gpt-4", name: "gpt-4"},
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
{id: "text-davinci-003", name: "text-davinci-003"},
{id: "text-davinci-002", name: "text-davinci-002"},
{id: "text-curie-001", name: "text-curie-001"},
{id: "text-babbage-001", name: "text-babbage-001"},
{id: "text-ada-001", name: "text-ada-001"},
{id: "text-davinci-001", name: "text-davinci-001"},
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
{id: "davinci", name: "davinci"},
{id: "curie-instruct-beta", name: "curie-instruct-beta"},
{id: "curie", name: "curie"},
{id: "ada", name: "ada"},
{id: "babbage", name: "babbage"},
{id: "OpenAI", name: "OpenAI"},
] ]
); );
} else {
return [];
}
}
export function getProviderSubTypeOptions(category, type) {
if (type === "OpenAI") {
if (category === "Model") {
return (
[
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
{id: "gpt-4-32k", name: "gpt-4-32k"},
{id: "gpt-4-0613", name: "gpt-4-0613"},
{id: "gpt-4-0314", name: "gpt-4-0314"},
{id: "gpt-4", name: "gpt-4"},
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
{id: "text-davinci-003", name: "text-davinci-003"},
{id: "text-davinci-002", name: "text-davinci-002"},
{id: "text-curie-001", name: "text-curie-001"},
{id: "text-babbage-001", name: "text-babbage-001"},
{id: "text-ada-001", name: "text-ada-001"},
{id: "text-davinci-001", name: "text-davinci-001"},
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
{id: "davinci", name: "davinci"},
{id: "curie-instruct-beta", name: "curie-instruct-beta"},
{id: "curie", name: "curie"},
{id: "ada", name: "ada"},
{id: "babbage", name: "babbage"},
]
);
} else if (category === "Embedding") {
return (
[
{id: "1", name: "AdaSimilarity"},
{id: "2", name: "BabbageSimilarity"},
{id: "3", name: "CurieSimilarity"},
{id: "4", name: "DavinciSimilarity"},
{id: "5", name: "AdaSearchDocument"},
{id: "6", name: "AdaSearchQuery"},
{id: "7", name: "BabbageSearchDocument"},
{id: "8", name: "BabbageSearchQuery"},
{id: "9", name: "CurieSearchDocument"},
{id: "10", name: "CurieSearchQuery"},
{id: "11", name: "DavinciSearchDocument"},
{id: "12", name: "DavinciSearchQuery"},
{id: "13", name: "AdaCodeSearchCode"},
{id: "14", name: "AdaCodeSearchText"},
{id: "15", name: "BabbageCodeSearchCode"},
{id: "16", name: "BabbageCodeSearchText"},
{id: "17", name: "AdaEmbeddingV2"},
]
);
} else {
return [];
}
} else if (type === "Hugging Face") { } else if (type === "Hugging Face") {
return ( return (
[ [


+ 14
- 2
web/src/StoreEditPage.js View File

@@ -30,6 +30,7 @@ class StoreEditPage extends React.Component {
storeName: props.match.params.storeName, storeName: props.match.params.storeName,
storageProviders: [], storageProviders: [],
modelProviders: [], modelProviders: [],
embeddingProviders: [],
store: null, store: null,
}; };
} }
@@ -37,7 +38,7 @@ class StoreEditPage extends React.Component {
UNSAFE_componentWillMount() { UNSAFE_componentWillMount() {
this.getStore(); this.getStore();
this.getStorageProviders(); this.getStorageProviders();
this.getModelProviders();
this.getProviders();
} }
getStore() { getStore() {
@@ -70,12 +71,13 @@ class StoreEditPage extends React.Component {
}); });
} }
getModelProviders() {
getProviders() {
ProviderBackend.getProviders(this.props.account.name) ProviderBackend.getProviders(this.props.account.name)
.then((res) => { .then((res) => {
if (res.status === "ok") { if (res.status === "ok") {
this.setState({ this.setState({
modelProviders: res.data.filter(provider => provider.category === "Model"), modelProviders: res.data.filter(provider => provider.category === "Model"),
embeddingProviders: res.data.filter(provider => provider.category === "Embedding"),
}); });
} else { } else {
Setting.showMessage("error", `Failed to get providers: ${res.msg}`); Setting.showMessage("error", `Failed to get providers: ${res.msg}`);
@@ -148,6 +150,16 @@ class StoreEditPage extends React.Component {
} /> } />
</Col> </Col>
</Row> </Row>
<Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("store:Embedding provider")}:
</Col>
<Col span={22} >
<Select virtual={false} style={{width: "100%"}} value={this.state.store.embeddingProvider} onChange={(value => {this.updateStoreField("embeddingProvider", value);})}
options={this.state.embeddingProviders.map((provider) => Setting.getOption(`${provider.displayName} (${provider.name})`, `${provider.name}`))
} />
</Col>
</Row>
<Row style={{marginTop: "20px"}} > <Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("store:File tree")}: {i18next.t("store:File tree")}:


+ 15
- 0
web/src/StoreListPage.js View File

@@ -56,6 +56,7 @@ class StoreListPage extends React.Component {
displayName: `New Store - ${randomName}`, displayName: `New Store - ${randomName}`,
storageProvider: "", storageProvider: "",
modelProvider: "", modelProvider: "",
embeddingProvider: "",
propertiesMap: {}, propertiesMap: {},
}; };
} }
@@ -168,6 +169,20 @@ class StoreListPage extends React.Component {
); );
}, },
}, },
{
title: i18next.t("store:Embedding provider"),
dataIndex: "embeddingProvider",
key: "embeddingProvider",
width: "250px",
sorter: (a, b) => a.embeddingProvider.localeCompare(b.embeddingProvider),
render: (text, record, index) => {
return (
<Link to={`/providers/${text}`}>
{text}
</Link>
);
},
},
{ {
title: i18next.t("general:Action"), title: i18next.t("general:Action"),
dataIndex: "action", dataIndex: "action",


Loading…
Cancel
Save