@@ -23,9 +23,9 @@ type ModelProvider interface { | |||||
QueryText(question string, writer io.Writer, builder *strings.Builder) error | QueryText(question string, writer io.Writer, builder *strings.Builder) error | ||||
} | } | ||||
func GetModelProvider(typ string, secretKey string) (ModelProvider, error) { | |||||
if typ == "OpenAI API - GPT 3.5" { | |||||
p, err := NewOpenaiGpt3p5ModelProvider(secretKey) | |||||
func GetModelProvider(typ string, subType string, secretKey string) (ModelProvider, error) { | |||||
if typ == "OpenAI API" { | |||||
p, err := NewOpenaiGpt3p5ModelProvider(subType, secretKey) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -25,11 +25,13 @@ import ( | |||||
) | ) | ||||
type OpenaiGpt3p5ModelProvider struct { | type OpenaiGpt3p5ModelProvider struct { | ||||
SubType string | |||||
SecretKey string | SecretKey string | ||||
} | } | ||||
func NewOpenaiGpt3p5ModelProvider(secretKey string) (*OpenaiGpt3p5ModelProvider, error) { | |||||
func NewOpenaiGpt3p5ModelProvider(subType string, secretKey string) (*OpenaiGpt3p5ModelProvider, error) { | |||||
p := &OpenaiGpt3p5ModelProvider{ | p := &OpenaiGpt3p5ModelProvider{ | ||||
SubType: subType, | |||||
SecretKey: secretKey, | SecretKey: secretKey, | ||||
} | } | ||||
return p, nil | return p, nil | ||||
@@ -43,9 +45,15 @@ func (p *OpenaiGpt3p5ModelProvider) QueryText(question string, writer io.Writer, | |||||
if !ok { | if !ok { | ||||
return fmt.Errorf("writer does not implement http.Flusher") | return fmt.Errorf("writer does not implement http.Flusher") | ||||
} | } | ||||
model := p.SubType | |||||
if model == "" { | |||||
model = openai.GPT3TextDavinci003 | |||||
} | |||||
// https://platform.openai.com/tokenizer | // https://platform.openai.com/tokenizer | ||||
// https://github.com/pkoukk/tiktoken-go#available-encodings | // https://github.com/pkoukk/tiktoken-go#available-encodings | ||||
promptTokens, err := GetTokenSize(openai.GPT3TextDavinci003, question) | |||||
promptTokens, err := GetTokenSize(model, question) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
@@ -56,7 +64,7 @@ func (p *OpenaiGpt3p5ModelProvider) QueryText(question string, writer io.Writer, | |||||
respStream, err := client.CreateCompletionStream( | respStream, err := client.CreateCompletionStream( | ||||
ctx, | ctx, | ||||
openai.CompletionRequest{ | openai.CompletionRequest{ | ||||
Model: openai.GPT3TextDavinci003, | |||||
Model: model, | |||||
Prompt: question, | Prompt: question, | ||||
MaxTokens: maxTokens, | MaxTokens: maxTokens, | ||||
Stream: true, | Stream: true, | ||||
@@ -30,6 +30,7 @@ type Provider struct { | |||||
DisplayName string `xorm:"varchar(100)" json:"displayName"` | DisplayName string `xorm:"varchar(100)" json:"displayName"` | ||||
Category string `xorm:"varchar(100)" json:"category"` | Category string `xorm:"varchar(100)" json:"category"` | ||||
Type string `xorm:"varchar(100)" json:"type"` | Type string `xorm:"varchar(100)" json:"type"` | ||||
SubType string `xorm:"varchar(100)" json:"subType"` | |||||
ClientId string `xorm:"varchar(100)" json:"clientId"` | ClientId string `xorm:"varchar(100)" json:"clientId"` | ||||
ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"` | ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"` | ||||
ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"` | ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"` | ||||
@@ -157,7 +158,7 @@ func (provider *Provider) GetId() string { | |||||
} | } | ||||
func (p *Provider) GetModelProvider() (ai.ModelProvider, error) { | func (p *Provider) GetModelProvider() (ai.ModelProvider, error) { | ||||
pProvider, err := ai.GetModelProvider(p.Type, p.ClientSecret) | |||||
pProvider, err := ai.GetModelProvider(p.Type, p.SubType, p.ClientSecret) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -117,13 +117,26 @@ class ProviderEditPage extends React.Component { | |||||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> | <Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}> | ||||
{ | { | ||||
[ | [ | ||||
{id: "OpenAI API - GPT 3.5", name: "OpenAI API - GPT 3.5"}, | |||||
{id: "OpenAI API - GPT 4", name: "OpenAI API - GPT 4"}, | |||||
{id: "OpenAI API", name: "OpenAI API"}, | |||||
].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | ].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | ||||
} | } | ||||
</Select> | </Select> | ||||
</Col> | </Col> | ||||
</Row> | </Row> | ||||
<Row style={{marginTop: "20px"}} > | |||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||||
{i18next.t("provider:Sub type")}: | |||||
</Col> | |||||
<Col span={22} > | |||||
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}> | |||||
{ | |||||
Setting.getProviderSubTypeOptions(this.state.provider.type) | |||||
.sort((a, b) => a.name.localeCompare(b.name)) | |||||
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>) | |||||
} | |||||
</Select> | |||||
</Col> | |||||
</Row> | |||||
<Row style={{marginTop: "20px"}} > | <Row style={{marginTop: "20px"}} > | ||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | ||||
{i18next.t("provider:Secret key")}: | {i18next.t("provider:Secret key")}: | ||||
@@ -54,7 +54,8 @@ class ProviderListPage extends React.Component { | |||||
createdTime: moment().format(), | createdTime: moment().format(), | ||||
displayName: `New Provider - ${randomName}`, | displayName: `New Provider - ${randomName}`, | ||||
category: "Model", | category: "Model", | ||||
type: "OpenAI API - GPT 3.5", | |||||
type: "OpenAI API", | |||||
subType: "gpt-3.5-turbo", | |||||
clientId: "", | clientId: "", | ||||
clientSecret: "", | clientSecret: "", | ||||
providerUrl: "https://platform.openai.com/account/api-keys", | providerUrl: "https://platform.openai.com/account/api-keys", | ||||
@@ -102,7 +103,7 @@ class ProviderListPage extends React.Component { | |||||
title: i18next.t("general:Name"), | title: i18next.t("general:Name"), | ||||
dataIndex: "name", | dataIndex: "name", | ||||
key: "name", | key: "name", | ||||
width: "140px", | |||||
width: "160px", | |||||
sorter: (a, b) => a.name.localeCompare(b.name), | sorter: (a, b) => a.name.localeCompare(b.name), | ||||
render: (text, record, index) => { | render: (text, record, index) => { | ||||
return ( | return ( | ||||
@@ -133,6 +134,13 @@ class ProviderListPage extends React.Component { | |||||
width: "160px", | width: "160px", | ||||
sorter: (a, b) => a.type.localeCompare(b.type), | sorter: (a, b) => a.type.localeCompare(b.type), | ||||
}, | }, | ||||
{ | |||||
title: i18next.t("provider:Sub type"), | |||||
dataIndex: "subType", | |||||
key: "subType", | |||||
width: "160px", | |||||
sorter: (a, b) => a.subType.localeCompare(b.subType), | |||||
}, | |||||
{ | { | ||||
title: i18next.t("provider:Secret key"), | title: i18next.t("provider:Secret key"), | ||||
dataIndex: "clientSecret", | dataIndex: "clientSecret", | ||||
@@ -144,13 +152,13 @@ class ProviderListPage extends React.Component { | |||||
title: i18next.t("provider:Provider URL"), | title: i18next.t("provider:Provider URL"), | ||||
dataIndex: "providerUrl", | dataIndex: "providerUrl", | ||||
key: "providerUrl", | key: "providerUrl", | ||||
width: "250px", | |||||
// width: "250px", | |||||
sorter: (a, b) => a.providerUrl.localeCompare(b.providerUrl), | sorter: (a, b) => a.providerUrl.localeCompare(b.providerUrl), | ||||
render: (text, record, index) => { | render: (text, record, index) => { | ||||
return ( | return ( | ||||
<a target="_blank" rel="noreferrer" href={text}> | <a target="_blank" rel="noreferrer" href={text}> | ||||
{ | { | ||||
Setting.getShortText(text) | |||||
Setting.getShortText(text, 80) | |||||
} | } | ||||
</a> | </a> | ||||
); | ); | ||||
@@ -652,3 +652,37 @@ export function renderExternalLink() { | |||||
export function isResponseDenied(data) { | export function isResponseDenied(data) { | ||||
return data.msg === "Unauthorized operation"; | return data.msg === "Unauthorized operation"; | ||||
} | } | ||||
export function getProviderSubTypeOptions(type) { | |||||
if (type === "OpenAI API") { | |||||
return ( | |||||
[ | |||||
{id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||||
{id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||||
{id: "gpt-4-32k", name: "gpt-4-32k"}, | |||||
{id: "gpt-4-0613", name: "gpt-4-0613"}, | |||||
{id: "gpt-4-0314", name: "gpt-4-0314"}, | |||||
{id: "gpt-4", name: "gpt-4"}, | |||||
{id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||||
{id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||||
{id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||||
{id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||||
{id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||||
{id: "text-davinci-003", name: "text-davinci-003"}, | |||||
{id: "text-davinci-002", name: "text-davinci-002"}, | |||||
{id: "text-curie-001", name: "text-curie-001"}, | |||||
{id: "text-babbage-001", name: "text-babbage-001"}, | |||||
{id: "text-ada-001", name: "text-ada-001"}, | |||||
{id: "text-davinci-001", name: "text-davinci-001"}, | |||||
{id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||||
{id: "davinci", name: "davinci"}, | |||||
{id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||||
{id: "curie", name: "curie"}, | |||||
{id: "ada", name: "ada"}, | |||||
{id: "babbage", name: "babbage"}, | |||||
] | |||||
); | |||||
} else { | |||||
return []; | |||||
} | |||||
} |