From 55cc17fc4a9212519f686dc779ec715ee4e60481 Mon Sep 17 00:00:00 2001 From: Kelvin Chiu Date: Fri, 22 Sep 2023 23:58:36 +0800 Subject: [PATCH] feat: support iFlytek model parameters (#644) --- model/iflytek.go | 25 +++++++++++++++--------- model/provider.go | 4 ++-- object/provider.go | 3 ++- web/src/ProviderEditPage.js | 38 ++++++++++++++++++++++++++++++++++++- web/src/ProviderListPage.js | 1 + 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/model/iflytek.go b/model/iflytek.go index 0892346..dc9c08f 100644 --- a/model/iflytek.go +++ b/model/iflytek.go @@ -24,18 +24,22 @@ import ( ) type iFlytekModelProvider struct { - subType string - appID string - apiKey string - secretKey string + subType string + appID string + apiKey string + secretKey string + temperature string + topK int } -func NewiFlytekModelProvider(subType string, secretKey string) (*iFlytekModelProvider, error) { +func NewiFlytekModelProvider(subType string, secretKey string, temperature float32, topK int) (*iFlytekModelProvider, error) { p := &iFlytekModelProvider{ - subType: subType, - appID: "", - apiKey: "", - secretKey: secretKey, + subType: subType, + appID: "", + apiKey: "", + secretKey: secretKey, + temperature: fmt.Sprintf("%f", temperature), + topK: topK, } return p, nil } @@ -55,6 +59,9 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, buil return fmt.Errorf("iflytek get session error: session is nil") } + session.Req.Parameter.Chat.Temperature = p.temperature + session.Req.Parameter.Chat.TopK = p.topK + response, err := session.Send(question) if err != nil { return fmt.Errorf("iflytek send error: %v", err) diff --git a/model/provider.go b/model/provider.go index 96fbd50..ec92aee 100644 --- a/model/provider.go +++ b/model/provider.go @@ -23,7 +23,7 @@ type ModelProvider interface { QueryText(question string, writer io.Writer, builder *strings.Builder) error } -func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { +func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { var p ModelProvider var err error if typ == "OpenAI" { @@ -37,7 +37,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret } else if typ == "Ernie" { p, err = NewErnieModelProvider(subType, clientId, clientSecret) } else if typ == "iFlytek" { - p, err = NewiFlytekModelProvider(subType, clientSecret) + p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) } else if typ == "ChatGLM" { p, err = NewChatGLMModelProvider(subType, clientSecret) } diff --git a/object/provider.go b/object/provider.go index 58ff721..8cfb334 100644 --- a/object/provider.go +++ b/object/provider.go @@ -39,6 +39,7 @@ type Provider struct { Temperature float32 `xorm:"float" json:"temperature"` TopP float32 `xorm:"float" json:"topP"` + TopK int `xorm:"int" json:"topK"` FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"` PresencePenalty float32 `xorm:"float" json:"presencePenalty"` } @@ -210,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { } func (p *Provider) GetModelProvider() (model.ModelProvider, error) { - pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.FrequencyPenalty, p.PresencePenalty) + pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) if err != nil { return nil, err } diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index ebf1f4c..6afd336 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -49,7 +49,7 @@ class ProviderEditPage extends React.Component { } parseProviderField(key, value) { - if ([""].includes(key)) { + if (["topK"].includes(key)) { value = Setting.myParseInt(value); } else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) { value = Setting.myParseFloat(value); @@ -279,6 +279,42 @@ class ProviderEditPage extends React.Component { ) : null } + { + (this.state.provider.category === "Model" && this.state.provider.type === "iFlytek") ? ( + <> + + + {i18next.t("provider:Temperature")}: + + { + this.updateProviderField("temperature", value); + }} + isMobile={Setting.isMobile()} + /> + + + + {i18next.t("provider:Top K")}: + + { + this.updateProviderField("topK", value); + }} + isMobile={Setting.isMobile()} + /> + + + ) : null + } {i18next.t("general:Provider URL")}: diff --git a/web/src/ProviderListPage.js b/web/src/ProviderListPage.js index 6ec6f33..d96790a 100644 --- a/web/src/ProviderListPage.js +++ b/web/src/ProviderListPage.js @@ -60,6 +60,7 @@ class ProviderListPage extends React.Component { clientSecret: "", temperature: 1, topP: 1, + topK: 4, frequencyPenalty: 0, presencePenalty: 0, providerUrl: "https://platform.openai.com/account/api-keys",