@@ -24,18 +24,22 @@ import ( | |||
) | |||
type iFlytekModelProvider struct { | |||
subType string | |||
appID string | |||
apiKey string | |||
secretKey string | |||
subType string | |||
appID string | |||
apiKey string | |||
secretKey string | |||
temperature string | |||
topK int | |||
} | |||
func NewiFlytekModelProvider(subType string, secretKey string) (*iFlytekModelProvider, error) { | |||
func NewiFlytekModelProvider(subType string, secretKey string, temperature float32, topK int) (*iFlytekModelProvider, error) { | |||
p := &iFlytekModelProvider{ | |||
subType: subType, | |||
appID: "", | |||
apiKey: "", | |||
secretKey: secretKey, | |||
subType: subType, | |||
appID: "", | |||
apiKey: "", | |||
secretKey: secretKey, | |||
temperature: fmt.Sprintf("%f", temperature), | |||
topK: topK, | |||
} | |||
return p, nil | |||
} | |||
@@ -55,6 +59,9 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, buil | |||
return fmt.Errorf("iflytek get session error: session is nil") | |||
} | |||
session.Req.Parameter.Chat.Temperature = p.temperature | |||
session.Req.Parameter.Chat.TopK = p.topK | |||
response, err := session.Send(question) | |||
if err != nil { | |||
return fmt.Errorf("iflytek send error: %v", err) | |||
@@ -23,7 +23,7 @@ type ModelProvider interface { | |||
QueryText(question string, writer io.Writer, builder *strings.Builder) error | |||
} | |||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||
var p ModelProvider | |||
var err error | |||
if typ == "OpenAI" { | |||
@@ -37,7 +37,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||
} else if typ == "Ernie" { | |||
p, err = NewErnieModelProvider(subType, clientId, clientSecret) | |||
} else if typ == "iFlytek" { | |||
p, err = NewiFlytekModelProvider(subType, clientSecret) | |||
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | |||
} else if typ == "ChatGLM" { | |||
p, err = NewChatGLMModelProvider(subType, clientSecret) | |||
} | |||
@@ -39,6 +39,7 @@ type Provider struct { | |||
Temperature float32 `xorm:"float" json:"temperature"` | |||
TopP float32 `xorm:"float" json:"topP"` | |||
TopK int `xorm:"int" json:"topK"` | |||
FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"` | |||
PresencePenalty float32 `xorm:"float" json:"presencePenalty"` | |||
} | |||
@@ -210,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { | |||
} | |||
func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.FrequencyPenalty, p.PresencePenalty) | |||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -49,7 +49,7 @@ class ProviderEditPage extends React.Component { | |||
} | |||
parseProviderField(key, value) { | |||
if ([""].includes(key)) { | |||
if (["topK"].includes(key)) { | |||
value = Setting.myParseInt(value); | |||
} else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) { | |||
value = Setting.myParseFloat(value); | |||
@@ -279,6 +279,42 @@ class ProviderEditPage extends React.Component { | |||
</> | |||
) : null | |||
} | |||
{ | |||
(this.state.provider.category === "Model" && this.state.provider.type === "iFlytek") ? ( | |||
<> | |||
<Row style={{marginTop: "20px"}}> | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("provider:Temperature")}: | |||
</Col> | |||
<this.InputSlider | |||
min={0} | |||
max={1} | |||
step={0.01} | |||
value={this.state.provider.temperature} | |||
onChange={(value) => { | |||
this.updateProviderField("temperature", value); | |||
}} | |||
isMobile={Setting.isMobile()} | |||
/> | |||
</Row> | |||
<Row style={{marginTop: "20px"}}> | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("provider:Top K")}: | |||
</Col> | |||
<this.InputSlider | |||
min={1} | |||
max={6} | |||
step={1} | |||
value={this.state.provider.topK} | |||
onChange={(value) => { | |||
this.updateProviderField("topK", value); | |||
}} | |||
isMobile={Setting.isMobile()} | |||
/> | |||
</Row> | |||
</> | |||
) : null | |||
} | |||
<Row style={{marginTop: "20px"}} > | |||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
{i18next.t("general:Provider URL")}: | |||
@@ -60,6 +60,7 @@ class ProviderListPage extends React.Component { | |||
clientSecret: "", | |||
temperature: 1, | |||
topP: 1, | |||
topK: 4, | |||
frequencyPenalty: 0, | |||
presencePenalty: 0, | |||
providerUrl: "https://platform.openai.com/account/api-keys", | |||