diff --git a/model/ernie.go b/model/ernie.go index c9e950e..e8066db 100644 --- a/model/ernie.go +++ b/model/ernie.go @@ -26,13 +26,23 @@ import ( ) type ErnieModelProvider struct { - subType string - apiKey string - secretKey string + subType string + apiKey string + secretKey string + temperature float32 + topP float32 + presencePenalty float32 } -func NewErnieModelProvider(subType string, apiKey string, secretKey string) (*ErnieModelProvider, error) { - return &ErnieModelProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil +func NewErnieModelProvider(subType string, apiKey string, secretKey string, temperature float32, topP float32, presencePenalty float32) (*ErnieModelProvider, error) { + return &ErnieModelProvider{ + subType: subType, + apiKey: apiKey, + secretKey: secretKey, + temperature: temperature, + topP: topP, + presencePenalty: presencePenalty, + }, nil } func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { @@ -59,8 +69,18 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde return nil } + temperature := p.temperature + topP := p.topP + presencePenalty := p.presencePenalty + if p.subType == "ERNIE-Bot" { - stream, err := client.CreateErnieBotChatCompletionStream(ctx, ernie.ErnieBotRequest{Messages: messages}) + stream, err := client.CreateErnieBotChatCompletionStream(ctx, + ernie.ErnieBotRequest{ + Messages: messages, + Temperature: temperature, + TopP: topP, + PresencePenalty: presencePenalty, + }) if err != nil { return err } @@ -82,7 +102,13 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde } } } else if p.subType == "ERNIE-Bot-turbo" { - stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, ernie.ErnieBotTurboRequest{Messages: messages}) + stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, + ernie.ErnieBotTurboRequest{ + Messages: messages, + Temperature: temperature, + TopP: topP, + PresencePenalty: presencePenalty, + }) if err != nil { return err } diff --git a/model/provider.go b/model/provider.go index ec92aee..4f9ce50 100644 --- a/model/provider.go +++ b/model/provider.go @@ -35,7 +35,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret } else if typ == "OpenRouter" { p, err = NewOpenRouterModelProvider(subType, clientSecret) } else if typ == "Ernie" { - p, err = NewErnieModelProvider(subType, clientId, clientSecret) + p, err = NewErnieModelProvider(subType, clientId, clientSecret, temperature, topP, presencePenalty) } else if typ == "iFlytek" { p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) } else if typ == "ChatGLM" { diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 6afd336..86b2067 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -315,6 +315,58 @@ class ProviderEditPage extends React.Component { ) : null } + { + (this.state.provider.category === "Model" && this.state.provider.type === "Ernie") ? ( + <> + + + {i18next.t("provider:Temperature")}: + + { + this.updateProviderField("temperature", value); + }} + isMobile={Setting.isMobile()} + /> + + + + {i18next.t("provider:Top P")}: + + { + this.updateProviderField("topP", value); + }} + isMobile={Setting.isMobile()} + /> + + + + {i18next.t("provider:Presence penalty")}: + + { + this.updateProviderField("presencePenalty", value); + }} + isMobile={Setting.isMobile()} + /> + + + ) : null + } {i18next.t("general:Provider URL")}: