diff --git a/model/ernie.go b/model/ernie.go
index c9e950e..e8066db 100644
--- a/model/ernie.go
+++ b/model/ernie.go
@@ -26,13 +26,23 @@ import (
)
type ErnieModelProvider struct {
- subType string
- apiKey string
- secretKey string
+ subType string
+ apiKey string
+ secretKey string
+ temperature float32
+ topP float32
+ presencePenalty float32
}
-func NewErnieModelProvider(subType string, apiKey string, secretKey string) (*ErnieModelProvider, error) {
- return &ErnieModelProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil
+func NewErnieModelProvider(subType string, apiKey string, secretKey string, temperature float32, topP float32, presencePenalty float32) (*ErnieModelProvider, error) {
+ return &ErnieModelProvider{
+ subType: subType,
+ apiKey: apiKey,
+ secretKey: secretKey,
+ temperature: temperature,
+ topP: topP,
+ presencePenalty: presencePenalty,
+ }, nil
}
func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
@@ -59,8 +69,18 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
return nil
}
+ temperature := p.temperature
+ topP := p.topP
+ presencePenalty := p.presencePenalty
+
if p.subType == "ERNIE-Bot" {
- stream, err := client.CreateErnieBotChatCompletionStream(ctx, ernie.ErnieBotRequest{Messages: messages})
+ stream, err := client.CreateErnieBotChatCompletionStream(ctx,
+ ernie.ErnieBotRequest{
+ Messages: messages,
+ Temperature: temperature,
+ TopP: topP,
+ PresencePenalty: presencePenalty,
+ })
if err != nil {
return err
}
@@ -82,7 +102,13 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
}
}
} else if p.subType == "ERNIE-Bot-turbo" {
- stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, ernie.ErnieBotTurboRequest{Messages: messages})
+ stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx,
+ ernie.ErnieBotTurboRequest{
+ Messages: messages,
+ Temperature: temperature,
+ TopP: topP,
+ PresencePenalty: presencePenalty,
+ })
if err != nil {
return err
}
diff --git a/model/provider.go b/model/provider.go
index ec92aee..4f9ce50 100644
--- a/model/provider.go
+++ b/model/provider.go
@@ -35,7 +35,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
} else if typ == "OpenRouter" {
p, err = NewOpenRouterModelProvider(subType, clientSecret)
} else if typ == "Ernie" {
- p, err = NewErnieModelProvider(subType, clientId, clientSecret)
+ p, err = NewErnieModelProvider(subType, clientId, clientSecret, temperature, topP, presencePenalty)
} else if typ == "iFlytek" {
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK)
} else if typ == "ChatGLM" {
diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js
index 6afd336..86b2067 100644
--- a/web/src/ProviderEditPage.js
+++ b/web/src/ProviderEditPage.js
@@ -315,6 +315,58 @@ class ProviderEditPage extends React.Component {
>
) : null
}
+ {
+ (this.state.provider.category === "Model" && this.state.provider.type === "Ernie") ? (
+ <>
+
+
+ {i18next.t("provider:Temperature")}:
+
+ {
+ this.updateProviderField("temperature", value);
+ }}
+ isMobile={Setting.isMobile()}
+ />
+
+
+
+ {i18next.t("provider:Top P")}:
+
+ {
+ this.updateProviderField("topP", value);
+ }}
+ isMobile={Setting.isMobile()}
+ />
+
+
+
+ {i18next.t("provider:Presence penalty")}:
+
+ {
+ this.updateProviderField("presencePenalty", value);
+ }}
+ isMobile={Setting.isMobile()}
+ />
+
+ >
+ ) : null
+ }
{i18next.t("general:Provider URL")}: