Browse Source

feat: support iFlytek model parameters (#644)

HEAD
Kelvin Chiu GitHub 2 years ago
parent
commit
55cc17fc4a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 13 deletions
  1. +16
    -9
      model/iflytek.go
  2. +2
    -2
      model/provider.go
  3. +2
    -1
      object/provider.go
  4. +37
    -1
      web/src/ProviderEditPage.js
  5. +1
    -0
      web/src/ProviderListPage.js

+ 16
- 9
model/iflytek.go View File

@@ -24,18 +24,22 @@ import (
) )


type iFlytekModelProvider struct { type iFlytekModelProvider struct {
subType string
appID string
apiKey string
secretKey string
subType string
appID string
apiKey string
secretKey string
temperature string
topK int
} }


func NewiFlytekModelProvider(subType string, secretKey string) (*iFlytekModelProvider, error) {
func NewiFlytekModelProvider(subType string, secretKey string, temperature float32, topK int) (*iFlytekModelProvider, error) {
p := &iFlytekModelProvider{ p := &iFlytekModelProvider{
subType: subType,
appID: "",
apiKey: "",
secretKey: secretKey,
subType: subType,
appID: "",
apiKey: "",
secretKey: secretKey,
temperature: fmt.Sprintf("%f", temperature),
topK: topK,
} }
return p, nil return p, nil
} }
@@ -55,6 +59,9 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, buil
return fmt.Errorf("iflytek get session error: session is nil") return fmt.Errorf("iflytek get session error: session is nil")
} }


session.Req.Parameter.Chat.Temperature = p.temperature
session.Req.Parameter.Chat.TopK = p.topK

response, err := session.Send(question) response, err := session.Send(question)
if err != nil { if err != nil {
return fmt.Errorf("iflytek send error: %v", err) return fmt.Errorf("iflytek send error: %v", err)


+ 2
- 2
model/provider.go View File

@@ -23,7 +23,7 @@ type ModelProvider interface {
QueryText(question string, writer io.Writer, builder *strings.Builder) error QueryText(question string, writer io.Writer, builder *strings.Builder) error
} }


func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) {
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) {
var p ModelProvider var p ModelProvider
var err error var err error
if typ == "OpenAI" { if typ == "OpenAI" {
@@ -37,7 +37,7 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
} else if typ == "Ernie" { } else if typ == "Ernie" {
p, err = NewErnieModelProvider(subType, clientId, clientSecret) p, err = NewErnieModelProvider(subType, clientId, clientSecret)
} else if typ == "iFlytek" { } else if typ == "iFlytek" {
p, err = NewiFlytekModelProvider(subType, clientSecret)
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK)
} else if typ == "ChatGLM" { } else if typ == "ChatGLM" {
p, err = NewChatGLMModelProvider(subType, clientSecret) p, err = NewChatGLMModelProvider(subType, clientSecret)
} }


+ 2
- 1
object/provider.go View File

@@ -39,6 +39,7 @@ type Provider struct {


Temperature float32 `xorm:"float" json:"temperature"` Temperature float32 `xorm:"float" json:"temperature"`
TopP float32 `xorm:"float" json:"topP"` TopP float32 `xorm:"float" json:"topP"`
TopK int `xorm:"int" json:"topK"`
FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"` FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"`
PresencePenalty float32 `xorm:"float" json:"presencePenalty"` PresencePenalty float32 `xorm:"float" json:"presencePenalty"`
} }
@@ -210,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) {
} }


func (p *Provider) GetModelProvider() (model.ModelProvider, error) { func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.FrequencyPenalty, p.PresencePenalty)
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty)
if err != nil { if err != nil {
return nil, err return nil, err
} }


+ 37
- 1
web/src/ProviderEditPage.js View File

@@ -49,7 +49,7 @@ class ProviderEditPage extends React.Component {
} }
parseProviderField(key, value) { parseProviderField(key, value) {
if ([""].includes(key)) {
if (["topK"].includes(key)) {
value = Setting.myParseInt(value); value = Setting.myParseInt(value);
} else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) { } else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) {
value = Setting.myParseFloat(value); value = Setting.myParseFloat(value);
@@ -279,6 +279,42 @@ class ProviderEditPage extends React.Component {
</> </>
) : null ) : null
} }
{
(this.state.provider.category === "Model" && this.state.provider.type === "iFlytek") ? (
<>
<Row style={{marginTop: "20px"}}>
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("provider:Temperature")}:
</Col>
<this.InputSlider
min={0}
max={1}
step={0.01}
value={this.state.provider.temperature}
onChange={(value) => {
this.updateProviderField("temperature", value);
}}
isMobile={Setting.isMobile()}
/>
</Row>
<Row style={{marginTop: "20px"}}>
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("provider:Top K")}:
</Col>
<this.InputSlider
min={1}
max={6}
step={1}
value={this.state.provider.topK}
onChange={(value) => {
this.updateProviderField("topK", value);
}}
isMobile={Setting.isMobile()}
/>
</Row>
</>
) : null
}
<Row style={{marginTop: "20px"}} > <Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{i18next.t("general:Provider URL")}: {i18next.t("general:Provider URL")}:


+ 1
- 0
web/src/ProviderListPage.js View File

@@ -60,6 +60,7 @@ class ProviderListPage extends React.Component {
clientSecret: "", clientSecret: "",
temperature: 1, temperature: 1,
topP: 1, topP: 1,
topK: 4,
frequencyPenalty: 0, frequencyPenalty: 0,
presencePenalty: 0, presencePenalty: 0,
providerUrl: "https://platform.openai.com/account/api-keys", providerUrl: "https://platform.openai.com/account/api-keys",


Loading…
Cancel
Save