@@ -26,13 +26,23 @@ import ( | |||
) | |||
type ErnieModelProvider struct { | |||
subType string | |||
apiKey string | |||
secretKey string | |||
subType string | |||
apiKey string | |||
secretKey string | |||
temperature float32 | |||
topP float32 | |||
presencePenalty float32 | |||
} | |||
func NewErnieModelProvider(subType string, apiKey string, secretKey string) (*ErnieModelProvider, error) { | |||
return &ErnieModelProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil | |||
func NewErnieModelProvider(subType string, apiKey string, secretKey string, temperature float32, topP float32, presencePenalty float32) (*ErnieModelProvider, error) { | |||
return &ErnieModelProvider{ | |||
subType: subType, | |||
apiKey: apiKey, | |||
secretKey: secretKey, | |||
temperature: temperature, | |||
topP: topP, | |||
presencePenalty: presencePenalty, | |||
}, nil | |||
} | |||
func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | |||
@@ -59,8 +69,18 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde | |||
return nil | |||
} | |||
temperature := p.temperature | |||
topP := p.topP | |||
presencePenalty := p.presencePenalty | |||
if p.subType == "ERNIE-Bot" { | |||
stream, err := client.CreateErnieBotChatCompletionStream(ctx, ernie.ErnieBotRequest{Messages: messages}) | |||
stream, err := client.CreateErnieBotChatCompletionStream(ctx, | |||
ernie.ErnieBotRequest{ | |||
Messages: messages, | |||
Temperature: temperature, | |||
TopP: topP, | |||
PresencePenalty: presencePenalty, | |||
}) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -82,7 +102,13 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde | |||
} | |||
} | |||
} else if p.subType == "ERNIE-Bot-turbo" { | |||
stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, ernie.ErnieBotTurboRequest{Messages: messages}) | |||
stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, | |||
ernie.ErnieBotTurboRequest{ | |||
Messages: messages, | |||
Temperature: temperature, | |||
TopP: topP, | |||
PresencePenalty: presencePenalty, | |||
}) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -35,7 +35,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||
} else if typ == "OpenRouter" { | |||
p, err = NewOpenRouterModelProvider(subType, clientSecret) | |||
} else if typ == "Ernie" { | |||
p, err = NewErnieModelProvider(subType, clientId, clientSecret) | |||
p, err = NewErnieModelProvider(subType, clientId, clientSecret, temperature, topP, presencePenalty) | |||
} else if typ == "iFlytek" { | |||
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | |||
} else if typ == "ChatGLM" { | |||
@@ -315,6 +315,58 @@ class ProviderEditPage extends React.Component { | |||
</> | |||
) : null | |||
} | |||
{ | |||
(this.state.provider.category === "Model" && this.state.provider.type === "Ernie") ? ( | |||
<> | |||
<Row style={{marginTop: "20px"}}> | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("provider:Temperature")}: | |||
</Col> | |||
<this.InputSlider | |||
min={0.01} | |||
max={1} | |||
step={0.01} | |||
value={this.state.provider.temperature} | |||
onChange={(value) => { | |||
this.updateProviderField("temperature", value); | |||
}} | |||
isMobile={Setting.isMobile()} | |||
/> | |||
</Row> | |||
<Row style={{marginTop: "20px"}}> | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("provider:Top P")}: | |||
</Col> | |||
<this.InputSlider | |||
min={0} | |||
max={1} | |||
step={0.01} | |||
value={this.state.provider.topP} | |||
onChange={(value) => { | |||
this.updateProviderField("topP", value); | |||
}} | |||
isMobile={Setting.isMobile()} | |||
/> | |||
</Row> | |||
<Row style={{marginTop: "20px"}}> | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("provider:Presence penalty")}: | |||
</Col> | |||
<this.InputSlider | |||
label={i18next.t("provider:Presence penalty")} | |||
min={1} | |||
max={2} | |||
step={0.01} | |||
value={this.state.provider.presencePenalty} | |||
onChange={(value) => { | |||
this.updateProviderField("presencePenalty", value); | |||
}} | |||
isMobile={Setting.isMobile()} | |||
/> | |||
</Row> | |||
</> | |||
) : null | |||
} | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("general:Provider URL")}: | |||