@@ -24,18 +24,22 @@ import ( | |||||
) | ) | ||||
type iFlytekModelProvider struct { | type iFlytekModelProvider struct { | ||||
subType string | |||||
appID string | |||||
apiKey string | |||||
secretKey string | |||||
subType string | |||||
appID string | |||||
apiKey string | |||||
secretKey string | |||||
temperature string | |||||
topK int | |||||
} | } | ||||
func NewiFlytekModelProvider(subType string, secretKey string) (*iFlytekModelProvider, error) { | |||||
func NewiFlytekModelProvider(subType string, secretKey string, temperature float32, topK int) (*iFlytekModelProvider, error) { | |||||
p := &iFlytekModelProvider{ | p := &iFlytekModelProvider{ | ||||
subType: subType, | |||||
appID: "", | |||||
apiKey: "", | |||||
secretKey: secretKey, | |||||
subType: subType, | |||||
appID: "", | |||||
apiKey: "", | |||||
secretKey: secretKey, | |||||
temperature: fmt.Sprintf("%f", temperature), | |||||
topK: topK, | |||||
} | } | ||||
return p, nil | return p, nil | ||||
} | } | ||||
@@ -55,6 +59,9 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, buil | |||||
return fmt.Errorf("iflytek get session error: session is nil") | return fmt.Errorf("iflytek get session error: session is nil") | ||||
} | } | ||||
session.Req.Parameter.Chat.Temperature = p.temperature | |||||
session.Req.Parameter.Chat.TopK = p.topK | |||||
response, err := session.Send(question) | response, err := session.Send(question) | ||||
if err != nil { | if err != nil { | ||||
return fmt.Errorf("iflytek send error: %v", err) | return fmt.Errorf("iflytek send error: %v", err) | ||||
@@ -23,7 +23,7 @@ type ModelProvider interface { | |||||
QueryText(question string, writer io.Writer, builder *strings.Builder) error | QueryText(question string, writer io.Writer, builder *strings.Builder) error | ||||
} | } | ||||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||||
var p ModelProvider | var p ModelProvider | ||||
var err error | var err error | ||||
if typ == "OpenAI" { | if typ == "OpenAI" { | ||||
@@ -37,7 +37,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||||
} else if typ == "Ernie" { | } else if typ == "Ernie" { | ||||
p, err = NewErnieModelProvider(subType, clientId, clientSecret) | p, err = NewErnieModelProvider(subType, clientId, clientSecret) | ||||
} else if typ == "iFlytek" { | } else if typ == "iFlytek" { | ||||
p, err = NewiFlytekModelProvider(subType, clientSecret) | |||||
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | |||||
} else if typ == "ChatGLM" { | } else if typ == "ChatGLM" { | ||||
p, err = NewChatGLMModelProvider(subType, clientSecret) | p, err = NewChatGLMModelProvider(subType, clientSecret) | ||||
} | } | ||||
@@ -39,6 +39,7 @@ type Provider struct { | |||||
Temperature float32 `xorm:"float" json:"temperature"` | Temperature float32 `xorm:"float" json:"temperature"` | ||||
TopP float32 `xorm:"float" json:"topP"` | TopP float32 `xorm:"float" json:"topP"` | ||||
TopK int `xorm:"int" json:"topK"` | |||||
FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"` | FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"` | ||||
PresencePenalty float32 `xorm:"float" json:"presencePenalty"` | PresencePenalty float32 `xorm:"float" json:"presencePenalty"` | ||||
} | } | ||||
@@ -210,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { | |||||
} | } | ||||
func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | ||||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.FrequencyPenalty, p.PresencePenalty) | |||||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -49,7 +49,7 @@ class ProviderEditPage extends React.Component { | |||||
} | } | ||||
parseProviderField(key, value) { | parseProviderField(key, value) { | ||||
if ([""].includes(key)) { | |||||
if (["topK"].includes(key)) { | |||||
value = Setting.myParseInt(value); | value = Setting.myParseInt(value); | ||||
} else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) { | } else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) { | ||||
value = Setting.myParseFloat(value); | value = Setting.myParseFloat(value); | ||||
@@ -279,6 +279,42 @@ class ProviderEditPage extends React.Component { | |||||
</> | </> | ||||
) : null | ) : null | ||||
} | } | ||||
{ | |||||
(this.state.provider.category === "Model" && this.state.provider.type === "iFlytek") ? ( | |||||
<> | |||||
<Row style={{marginTop: "20px"}}> | |||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||||
{i18next.t("provider:Temperature")}: | |||||
</Col> | |||||
<this.InputSlider | |||||
min={0} | |||||
max={1} | |||||
step={0.01} | |||||
value={this.state.provider.temperature} | |||||
onChange={(value) => { | |||||
this.updateProviderField("temperature", value); | |||||
}} | |||||
isMobile={Setting.isMobile()} | |||||
/> | |||||
</Row> | |||||
<Row style={{marginTop: "20px"}}> | |||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||||
{i18next.t("provider:Top K")}: | |||||
</Col> | |||||
<this.InputSlider | |||||
min={1} | |||||
max={6} | |||||
step={1} | |||||
value={this.state.provider.topK} | |||||
onChange={(value) => { | |||||
this.updateProviderField("topK", value); | |||||
}} | |||||
isMobile={Setting.isMobile()} | |||||
/> | |||||
</Row> | |||||
</> | |||||
) : null | |||||
} | |||||
<Row style={{marginTop: "20px"}} > | <Row style={{marginTop: "20px"}} > | ||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | ||||
{i18next.t("general:Provider URL")}: | {i18next.t("general:Provider URL")}: | ||||
@@ -60,6 +60,7 @@ class ProviderListPage extends React.Component { | |||||
clientSecret: "", | clientSecret: "", | ||||
temperature: 1, | temperature: 1, | ||||
topP: 1, | topP: 1, | ||||
topK: 4, | |||||
frequencyPenalty: 0, | frequencyPenalty: 0, | ||||
presencePenalty: 0, | presencePenalty: 0, | ||||
providerUrl: "https://platform.openai.com/account/api-keys", | providerUrl: "https://platform.openai.com/account/api-keys", | ||||