@@ -23,9 +23,9 @@ type ModelProvider interface {
 	QueryText(question string, writer io.Writer, builder *strings.Builder) error
 }
 
-func GetModelProvider(typ string, secretKey string) (ModelProvider, error) {
-	if typ == "OpenAI API - GPT 3.5" {
-		p, err := NewOpenaiGpt3p5ModelProvider(secretKey)
+func GetModelProvider(typ string, subType string, secretKey string) (ModelProvider, error) {
+	if typ == "OpenAI API" {
+		p, err := NewOpenaiGpt3p5ModelProvider(subType, secretKey)
 		if err != nil {
 			return nil, err
 		}
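
Reviewer note: a minimal sketch of the new call shape, assuming it lives in the same package as `GetModelProvider` (so no import path has to be guessed) and that `io` and `strings` are imported. The helper name `askOnce` and the question text are illustrative, not part of this change; "OpenAI API" and "gpt-3.5-turbo" are values this diff makes selectable.

```go
// askOnce is a hypothetical helper showing how callers now pass the provider
// type and the model sub type separately.
func askOnce(question string, secretKey string, w io.Writer) (string, error) {
	provider, err := GetModelProvider("OpenAI API", "gpt-3.5-turbo", secretKey)
	if err != nil {
		return "", err
	}

	// The OpenAI implementation expects w to also satisfy http.Flusher
	// (e.g. an http.ResponseWriter), since it streams the answer.
	builder := &strings.Builder{}
	if err := provider.QueryText(question, w, builder); err != nil {
		return "", err
	}
	return builder.String(), nil
}
```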
@@ -25,11 +25,13 @@ import (
 )
 
 type OpenaiGpt3p5ModelProvider struct {
+	SubType   string
 	SecretKey string
 }
 
-func NewOpenaiGpt3p5ModelProvider(secretKey string) (*OpenaiGpt3p5ModelProvider, error) {
+func NewOpenaiGpt3p5ModelProvider(subType string, secretKey string) (*OpenaiGpt3p5ModelProvider, error) {
 	p := &OpenaiGpt3p5ModelProvider{
+		SubType:   subType,
 		SecretKey: secretKey,
 	}
 	return p, nil
@@ -43,9 +45,15 @@ func (p *OpenaiGpt3p5ModelProvider) QueryText(question string, writer io.Writer,
 	if !ok {
 		return fmt.Errorf("writer does not implement http.Flusher")
 	}
+
+	model := p.SubType
+	if model == "" {
+		model = openai.GPT3TextDavinci003
+	}
+
 	// https://platform.openai.com/tokenizer
 	// https://github.com/pkoukk/tiktoken-go#available-encodings
-	promptTokens, err := GetTokenSize(openai.GPT3TextDavinci003, question)
+	promptTokens, err := GetTokenSize(model, question)
 	if err != nil {
 		return err
 	}
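
The fallback keeps the old behavior for providers saved before this change: an empty `SubType` still resolves to `openai.GPT3TextDavinci003`, and token counting now uses whichever model is selected. A standalone sketch of just that fallback (the `resolveModel` name is hypothetical; `QueryText` inlines this logic):

```go
// resolveModel mirrors the fallback added to QueryText: an empty sub type
// keeps the previous default model, otherwise the configured one is used.
func resolveModel(subType string) string {
	if subType == "" {
		return openai.GPT3TextDavinci003 // "text-davinci-003"
	}
	return subType
}
```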
@@ -56,7 +64,7 @@ func (p *OpenaiGpt3p5ModelProvider) QueryText(question string, writer io.Writer,
 	respStream, err := client.CreateCompletionStream(
 		ctx,
 		openai.CompletionRequest{
-			Model:     openai.GPT3TextDavinci003,
+			Model:     model,
 			Prompt:    question,
 			MaxTokens: maxTokens,
 			Stream:    true,
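
For context, the stream returned by `CreateCompletionStream` is consumed further down in `QueryText`; that part is unchanged and not shown in this diff. A hedged sketch of a typical receive loop with `github.com/sashabaranov/go-openai`, assuming `flusher` is the `http.Flusher` obtained from the type assertion checked earlier (the repository's exact code may differ):

```go
// Illustrative stream consumption; not the repository's exact implementation.
defer respStream.Close()
for {
	resp, err := respStream.Recv()
	if errors.Is(err, io.EOF) {
		break // the model finished streaming
	}
	if err != nil {
		return err
	}
	if len(resp.Choices) == 0 {
		continue
	}

	data := resp.Choices[0].Text
	if _, err := fmt.Fprint(writer, data); err != nil {
		return err
	}
	flusher.Flush()           // push the chunk to the client immediately
	builder.WriteString(data) // keep the full answer in the builder
}
```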
@@ -30,6 +30,7 @@ type Provider struct {
 	DisplayName string `xorm:"varchar(100)" json:"displayName"`
 	Category string `xorm:"varchar(100)" json:"category"`
 	Type string `xorm:"varchar(100)" json:"type"`
+	SubType string `xorm:"varchar(100)" json:"subType"`
 	ClientId string `xorm:"varchar(100)" json:"clientId"`
 	ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"`
 	ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
@@ -157,7 +158,7 @@ func (provider *Provider) GetId() string {
 }
 
 func (p *Provider) GetModelProvider() (ai.ModelProvider, error) {
-	pProvider, err := ai.GetModelProvider(p.Type, p.ClientSecret)
+	pProvider, err := ai.GetModelProvider(p.Type, p.SubType, p.ClientSecret)
 	if err != nil {
 		return nil, err
 	}
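
End to end, a `Provider` record configured in the UI now carries the model name in `SubType`, and `GetModelProvider` forwards it to the `ai` package. A hypothetical sketch of that wiring from the `object` side (field values are examples only; in practice the record comes from the database):

```go
// Hypothetical usage of the updated Provider struct.
p := &Provider{
	Category:     "Model",
	Type:         "OpenAI API",
	SubType:      "gpt-3.5-turbo",
	ClientSecret: "sk-...",
}

modelProvider, err := p.GetModelProvider()
if err != nil {
	return err
}
_ = modelProvider // later used to answer chat questions via QueryText
```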
@@ -117,13 +117,26 @@ class ProviderEditPage extends React.Component {
             <Select virtual={false} style={{width: "100%"}} value={this.state.provider.type} onChange={(value => {this.updateProviderField("type", value);})}>
               {
                 [
-                  {id: "OpenAI API - GPT 3.5", name: "OpenAI API - GPT 3.5"},
-                  {id: "OpenAI API - GPT 4", name: "OpenAI API - GPT 4"},
+                  {id: "OpenAI API", name: "OpenAI API"},
                 ].map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
               }
             </Select>
           </Col>
         </Row>
+        <Row style={{marginTop: "20px"}} >
+          <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
+            {i18next.t("provider:Sub type")}:
+          </Col>
+          <Col span={22} >
+            <Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}>
+              {
+                Setting.getProviderSubTypeOptions(this.state.provider.type)
+                  .sort((a, b) => a.name.localeCompare(b.name))
+                  .map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
+              }
+            </Select>
+          </Col>
+        </Row>
         <Row style={{marginTop: "20px"}} >
           <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
             {i18next.t("provider:Secret key")}:
@@ -54,7 +54,8 @@ class ProviderListPage extends React.Component {
       createdTime: moment().format(),
       displayName: `New Provider - ${randomName}`,
       category: "Model",
-      type: "OpenAI API - GPT 3.5",
+      type: "OpenAI API",
+      subType: "gpt-3.5-turbo",
       clientId: "",
       clientSecret: "",
       providerUrl: "https://platform.openai.com/account/api-keys",
@@ -102,7 +103,7 @@ class ProviderListPage extends React.Component {
         title: i18next.t("general:Name"),
         dataIndex: "name",
         key: "name",
-        width: "140px",
+        width: "160px",
         sorter: (a, b) => a.name.localeCompare(b.name),
         render: (text, record, index) => {
           return (
@@ -133,6 +134,13 @@ class ProviderListPage extends React.Component {
         width: "160px",
         sorter: (a, b) => a.type.localeCompare(b.type),
       },
+      {
+        title: i18next.t("provider:Sub type"),
+        dataIndex: "subType",
+        key: "subType",
+        width: "160px",
+        sorter: (a, b) => a.subType.localeCompare(b.subType),
+      },
       {
         title: i18next.t("provider:Secret key"),
         dataIndex: "clientSecret",
@@ -144,13 +152,13 @@ class ProviderListPage extends React.Component {
         title: i18next.t("provider:Provider URL"),
         dataIndex: "providerUrl",
         key: "providerUrl",
-        width: "250px",
+        // width: "250px",
         sorter: (a, b) => a.providerUrl.localeCompare(b.providerUrl),
         render: (text, record, index) => {
           return (
             <a target="_blank" rel="noreferrer" href={text}>
               {
-                Setting.getShortText(text)
+                Setting.getShortText(text, 80)
               }
             </a>
           );
@@ -652,3 +652,37 @@ export function renderExternalLink() {
 export function isResponseDenied(data) {
   return data.msg === "Unauthorized operation";
 }
+
+export function getProviderSubTypeOptions(type) {
+  if (type === "OpenAI API") {
+    return (
+      [
+        {id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
+        {id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
+        {id: "gpt-4-32k", name: "gpt-4-32k"},
+        {id: "gpt-4-0613", name: "gpt-4-0613"},
+        {id: "gpt-4-0314", name: "gpt-4-0314"},
+        {id: "gpt-4", name: "gpt-4"},
+        {id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
+        {id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
+        {id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
+        {id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
+        {id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
+        {id: "text-davinci-003", name: "text-davinci-003"},
+        {id: "text-davinci-002", name: "text-davinci-002"},
+        {id: "text-curie-001", name: "text-curie-001"},
+        {id: "text-babbage-001", name: "text-babbage-001"},
+        {id: "text-ada-001", name: "text-ada-001"},
+        {id: "text-davinci-001", name: "text-davinci-001"},
+        {id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
+        {id: "davinci", name: "davinci"},
+        {id: "curie-instruct-beta", name: "curie-instruct-beta"},
+        {id: "curie", name: "curie"},
+        {id: "ada", name: "ada"},
+        {id: "babbage", name: "babbage"},
+      ]
+    );
+  } else {
+    return [];
+  }
+}