* feat: support Azure model provider
* fix: convert apiVersion to string
@@ -0,0 +1,55 @@
+// Copyright 2023 The casbin Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package model
+
+import (
+	"github.com/casbin/casibase/proxy"
+	"github.com/sashabaranov/go-openai"
+)
+
+func NewAzureModelProvider(typ string, subType string, deploymentName string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (*LocalModelProvider, error) {
+	p := &LocalModelProvider{
+		typ:              typ,
+		subType:          subType,
+		deploymentName:   deploymentName,
+		secretKey:        secretKey,
+		temperature:      temperature,
+		topP:             topP,
+		frequencyPenalty: frequencyPenalty,
+		presencePenalty:  presencePenalty,
+		providerUrl:      providerUrl,
+		apiVersion:       apiVersion,
+	}
+	return p, nil
+}
+
+func getAzureClientFromToken(subtype string, deploymentName string, authToken string, url string, apiVersion string) *openai.Client {
+	config := openai.DefaultAzureConfig(authToken, url)
+	config.HTTPClient = proxy.ProxyHttpClient
+	if apiVersion != "" {
+		config.APIVersion = apiVersion
+	}
+	if deploymentName != "" {
+		config.AzureModelMapperFunc = func(model string) string {
+			azureModelMapping := map[string]string{
+				subtype: deploymentName,
+			}
+			return azureModelMapping[model]
+		}
+	}
+	c := openai.NewClientWithConfig(config)
+	return c
+}
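A note on how the two new functions compose: the constructor only records the settings on `LocalModelProvider`, while `getAzureClientFromToken` turns them into a go-openai Azure client whose model mapper resolves the selected model (`subType`) to the Azure deployment name. Below is a minimal sketch of that client construction; the endpoint, key, and deployment name are hypothetical placeholders, not values from this PR.

```go
package main

import (
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// Hypothetical Azure OpenAI settings; none of these values come from the PR.
	endpoint := "https://my-resource.openai.azure.com/"
	apiKey := "AZURE_OPENAI_KEY"
	subType := "gpt-4"          // model name chosen in the provider UI
	deploymentName := "my-gpt4" // Azure deployment that serves this model

	// Mirror what getAzureClientFromToken does: start from the Azure defaults,
	// pin the API version, and map the model name onto the deployment name.
	config := openai.DefaultAzureConfig(apiKey, endpoint)
	config.APIVersion = "2023-05-15"
	config.AzureModelMapperFunc = func(model string) string {
		return map[string]string{subType: deploymentName}[model]
	}
	client := openai.NewClientWithConfig(config)

	_ = client // requests for "gpt-4" are now routed to the "my-gpt4" deployment
	fmt.Println(config.AzureModelMapperFunc(subType))
}
```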
@@ -25,17 +25,21 @@ import (
 )

 type LocalModelProvider struct {
+	typ              string
 	subType          string
+	deploymentName   string
 	secretKey        string
 	temperature      float32
 	topP             float32
 	frequencyPenalty float32
 	presencePenalty  float32
 	providerUrl      string
+	apiVersion       string
 }

-func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) {
+func NewLocalModelProvider(typ string, subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) {
 	p := &LocalModelProvider{
+		typ:              typ,
 		subType:          subType,
 		secretKey:        secretKey,
 		temperature:      temperature,
@@ -56,7 +60,12 @@ func getLocalClientFromUrl(authToken string, url string) *openai.Client {
 }

 func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
-	client := getLocalClientFromUrl(p.secretKey, p.providerUrl)
+	var client *openai.Client
+	if p.typ == "Local" {
+		client = getLocalClientFromUrl(p.secretKey, p.providerUrl)
+	} else if p.typ == "Azure" {
+		client = getAzureClientFromToken(p.subType, p.deploymentName, p.secretKey, p.providerUrl, p.apiVersion)
+	}

 	ctx := context.Background()
 	flusher, ok := writer.(http.Flusher)
@@ -70,11 +79,16 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde
 	frequencyPenalty := p.frequencyPenalty
 	presencePenalty := p.presencePenalty

-	respStream, err := client.CreateCompletionStream(
+	respStream, err := client.CreateChatCompletionStream(
 		ctx,
-		openai.CompletionRequest{
-			Model:  model,
-			Prompt: question,
+		openai.ChatCompletionRequest{
+			Model: model,
+			Messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleUser,
+					Content: question,
+				},
+			},
 			Stream:           true,
 			Temperature:      temperature,
 			TopP:             topP,
@@ -97,7 +111,7 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde
 			return streamErr
 		}

-		data := completion.Choices[0].Text
+		data := completion.Choices[0].Delta.Content
 		if isLeadingReturn && len(data) != 0 {
 			if strings.Count(data, "\n") == len(data) {
 				continue
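Because the stream type changes from the legacy completion stream to the chat completion stream, each received chunk now carries its text in `Choices[0].Delta.Content` rather than `Choices[0].Text`. The following self-contained sketch shows that consumption pattern in isolation; the client, model constant, and prompt are placeholders, not code from this PR.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/sashabaranov/go-openai"
)

// streamChat drains a chat-completion stream and accumulates the delta chunks,
// mirroring the loop shape used in QueryText.
func streamChat(client *openai.Client, model string, question string) (string, error) {
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model: model,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: question},
		},
		Stream: true,
	})
	if err != nil {
		return "", err
	}
	defer stream.Close()

	builder := strings.Builder{}
	for {
		completion, streamErr := stream.Recv()
		if streamErr != nil {
			if errors.Is(streamErr, io.EOF) {
				break // stream finished normally
			}
			return "", streamErr
		}
		// Chat streams deliver incremental text in Delta.Content,
		// not in the legacy Choices[0].Text field.
		builder.WriteString(completion.Choices[0].Delta.Content)
	}
	return builder.String(), nil
}

func main() {
	client := openai.NewClient("OPENAI_API_KEY") // hypothetical key
	answer, err := streamChat(client, openai.GPT3Dot5Turbo, "Hello!")
	fmt.Println(answer, err)
}
```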
@@ -23,7 +23,7 @@ type ModelProvider interface {
 	QueryText(question string, writer io.Writer, builder *strings.Builder) error
 }

-func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) {
+func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (ModelProvider, error) {
 	var p ModelProvider
 	var err error
 	if typ == "OpenAI" {
@@ -43,7 +43,9 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
 	} else if typ == "MiniMax" {
 		p, err = NewMiniMaxModelProvider(subType, clientId, clientSecret, temperature)
 	} else if typ == "Local" {
-		p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl)
+		p, err = NewLocalModelProvider(typ, subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl)
+	} else if typ == "Azure" {
+		p, err = NewAzureModelProvider(typ, subType, clientId, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl, apiVersion)
 	}

 	if err != nil {
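For callers, the Azure case reuses the existing factory signature: `clientId` carries the Azure deployment name, `clientSecret` the API key, `providerUrl` the resource endpoint, and the new trailing `apiVersion` parameter selects the Azure API version. A hedged usage sketch; every value below is a placeholder.

```go
package main

import (
	"fmt"

	"github.com/casbin/casibase/model"
)

func main() {
	// All values below are hypothetical placeholders.
	p, err := model.GetModelProvider(
		"Azure",              // typ
		"gpt-4",              // subType: model name
		"my-gpt4-deployment", // clientId: reused as the Azure deployment name
		"AZURE_OPENAI_KEY",   // clientSecret: Azure OpenAI API key
		0.7,                  // temperature
		1.0,                  // topP
		0,                    // topK (not used by the Azure/Local providers)
		0.0,                  // frequencyPenalty
		0.0,                  // presencePenalty
		"https://my-resource.openai.azure.com/", // providerUrl: resource endpoint
		"2023-05-15",                            // apiVersion
	)
	if err != nil {
		panic(err)
	}
	// p satisfies model.ModelProvider; QueryText is normally driven from an
	// HTTP handler whose ResponseWriter implements http.Flusher.
	fmt.Printf("%T\n", p)
}
```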
@@ -36,6 +36,7 @@ type Provider struct {
 	ClientId         string  `xorm:"varchar(100)" json:"clientId"`
 	ClientSecret     string  `xorm:"varchar(2000)" json:"clientSecret"`
 	ProviderUrl      string  `xorm:"varchar(200)" json:"providerUrl"`
+	ApiVersion       string  `xorm:"varchar(100)" json:"apiVersion"`
 	Temperature      float32 `xorm:"float" json:"temperature"`
 	TopP             float32 `xorm:"float" json:"topP"`
@@ -211,7 +212,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) {
 }

 func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
-	pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl)
+	pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl, p.ApiVersion)
 	if err != nil {
 		return nil, err
 	}
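At the storage level, an Azure provider row therefore needs no new columns beyond `ApiVersion`: the existing `ClientId` column is reused for the deployment name (the edit page below labels that field "Deployment Name"). The sketch below shows such a record, assuming the `Provider` struct lives in the repository's `object` package; every value is a placeholder.

```go
package main

import (
	"fmt"

	"github.com/casbin/casibase/object" // assumed import path for the Provider struct
)

func main() {
	// Hypothetical Azure model provider record.
	p := object.Provider{
		Type:         "Azure",
		SubType:      "gpt-4",                                 // model name
		ClientId:     "my-gpt4-deployment",                    // reused as the deployment name
		ClientSecret: "AZURE_OPENAI_KEY",                      // API key
		ProviderUrl:  "https://my-resource.openai.azure.com/", // resource endpoint
		ApiVersion:   "2023-05-15",                            // new column added by this PR
		Temperature:  0.7,
		TopP:         1.0,
	}

	modelProvider, err := p.GetModelProvider()
	fmt.Println(modelProvider, err)
}
```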
@@ -105,6 +105,14 @@ class ProviderEditPage extends React.Component {
     );
   }

+  handleTagChange = (key, value) => {
+    if (Array.isArray(value) && value.length > 0) {
+      this.updateProviderField(key, value[value.length - 1]);
+    } else {
+      this.updateProviderField(key, value);
+    }
+  };
+
   renderProvider() {
     return (
       <Card size="small" title={
@@ -186,6 +194,8 @@ class ProviderEditPage extends React.Component {
           this.updateProviderField("subType", "chatglm2-6b");
         } else if (value === "Local") {
           this.updateProviderField("subType", "custom-model");
+        } else if (value === "Azure") {
+          this.updateProviderField("subType", "gpt-4");
         }
       } else if (this.state.provider.category === "Embedding") {
         if (value === "OpenAI") {
@@ -480,6 +490,35 @@ class ProviderEditPage extends React.Component {
             </>
           ) : null
         }
+        {
+          ((this.state.provider.category === "Model") && this.state.provider.type === "Azure") ? (
+            <>
+              <Row style={{marginTop: "20px"}}>
+                <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
+                  {i18next.t("provider:Deployment Name")}:
+                </Col>
+                <Col span={22} >
+                  <Input value={this.state.provider.clientId} onChange={e => {
+                    this.updateProviderField("clientId", e.target.value);
+                  }} />
+                </Col>
+              </Row>
+              <Row style={{marginTop: "20px"}}>
+                <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
+                  {i18next.t("provider:API Version")}:
+                </Col>
+                <Col span={22} >
+                  <Select virtual={false} mode="tags" style={{width: "100%"}}
+                    value={this.state.provider.apiVersion}
+                    onSelect={(value) => {this.handleTagChange("apiVersion", value);}}
+                    onChange={(value) => {this.handleTagChange("apiVersion", value);}}
+                    options={Setting.getProviderAzureApiVersionOptions().map((item) => Setting.getOption(item.name, item.id))}
+                  />
+                </Col>
+              </Row>
+            </>
+          ) : null
+        }
         <Row style={{marginTop: "20px"}} >
           <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
             {i18next.t("general:Provider URL")}:
@@ -64,6 +64,7 @@ class ProviderListPage extends React.Component {
       frequencyPenalty: 0,
       presencePenalty: 0,
       providerUrl: "https://platform.openai.com/account/api-keys",
+      apiVersion: "",
     };
   }
@@ -648,6 +648,7 @@ export function getProviderTypeOptions(category) {
         {id: "ChatGLM", name: "ChatGLM"},
         {id: "MiniMax", name: "MiniMax"},
         {id: "Local", name: "Local"},
+        {id: "Azure", name: "Azure"},
       ]
     );
   } else if (category === "Embedding") {
@@ -665,57 +666,61 @@ export function getProviderTypeOptions(category) {
   }
 }

+const openaiModels = [
+  {id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
+  {id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
+  {id: "gpt-4-32k", name: "gpt-4-32k"},
+  {id: "gpt-4-0613", name: "gpt-4-0613"},
+  {id: "gpt-4-0314", name: "gpt-4-0314"},
+  {id: "gpt-4", name: "gpt-4"},
+  {id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
+  {id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
+  {id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
+  {id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
+  {id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
+  {id: "text-davinci-003", name: "text-davinci-003"},
+  {id: "text-davinci-002", name: "text-davinci-002"},
+  {id: "text-curie-001", name: "text-curie-001"},
+  {id: "text-babbage-001", name: "text-babbage-001"},
+  {id: "text-ada-001", name: "text-ada-001"},
+  {id: "text-davinci-001", name: "text-davinci-001"},
+  {id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
+  {id: "davinci", name: "davinci"},
+  {id: "curie-instruct-beta", name: "curie-instruct-beta"},
+  {id: "curie", name: "curie"},
+  {id: "ada", name: "ada"},
+  {id: "babbage", name: "babbage"},
+];
+
+const openaiEmbeddings = [
+  {id: "1", name: "AdaSimilarity"},
+  {id: "2", name: "BabbageSimilarity"},
+  {id: "3", name: "CurieSimilarity"},
+  {id: "4", name: "DavinciSimilarity"},
+  {id: "5", name: "AdaSearchDocument"},
+  {id: "6", name: "AdaSearchQuery"},
+  {id: "7", name: "BabbageSearchDocument"},
+  {id: "8", name: "BabbageSearchQuery"},
+  {id: "9", name: "CurieSearchDocument"},
+  {id: "10", name: "CurieSearchQuery"},
+  {id: "11", name: "DavinciSearchDocument"},
+  {id: "12", name: "DavinciSearchQuery"},
+  {id: "13", name: "AdaCodeSearchCode"},
+  {id: "14", name: "AdaCodeSearchText"},
+  {id: "15", name: "BabbageCodeSearchCode"},
+  {id: "16", name: "BabbageCodeSearchText"},
+  {id: "17", name: "AdaEmbeddingV2"},
+];
+
 export function getProviderSubTypeOptions(category, type) {
   if (type === "OpenAI") {
     if (category === "Model") {
       return (
-        [
-          {id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"},
-          {id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"},
-          {id: "gpt-4-32k", name: "gpt-4-32k"},
-          {id: "gpt-4-0613", name: "gpt-4-0613"},
-          {id: "gpt-4-0314", name: "gpt-4-0314"},
-          {id: "gpt-4", name: "gpt-4"},
-          {id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"},
-          {id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"},
-          {id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"},
-          {id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"},
-          {id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"},
-          {id: "text-davinci-003", name: "text-davinci-003"},
-          {id: "text-davinci-002", name: "text-davinci-002"},
-          {id: "text-curie-001", name: "text-curie-001"},
-          {id: "text-babbage-001", name: "text-babbage-001"},
-          {id: "text-ada-001", name: "text-ada-001"},
-          {id: "text-davinci-001", name: "text-davinci-001"},
-          {id: "davinci-instruct-beta", name: "davinci-instruct-beta"},
-          {id: "davinci", name: "davinci"},
-          {id: "curie-instruct-beta", name: "curie-instruct-beta"},
-          {id: "curie", name: "curie"},
-          {id: "ada", name: "ada"},
-          {id: "babbage", name: "babbage"},
-        ]
+        openaiModels
       );
     } else if (category === "Embedding") {
       return (
-        [
-          {id: "1", name: "AdaSimilarity"},
-          {id: "2", name: "BabbageSimilarity"},
-          {id: "3", name: "CurieSimilarity"},
-          {id: "4", name: "DavinciSimilarity"},
-          {id: "5", name: "AdaSearchDocument"},
-          {id: "6", name: "AdaSearchQuery"},
-          {id: "7", name: "BabbageSearchDocument"},
-          {id: "8", name: "BabbageSearchQuery"},
-          {id: "9", name: "CurieSearchDocument"},
-          {id: "10", name: "CurieSearchQuery"},
-          {id: "11", name: "DavinciSearchDocument"},
-          {id: "12", name: "DavinciSearchQuery"},
-          {id: "13", name: "AdaCodeSearchCode"},
-          {id: "14", name: "AdaCodeSearchText"},
-          {id: "15", name: "BabbageCodeSearchCode"},
-          {id: "16", name: "BabbageCodeSearchText"},
-          {id: "17", name: "AdaEmbeddingV2"},
-        ]
+        openaiEmbeddings
       );
     } else {
       return [];
@@ -845,7 +850,24 @@ export function getProviderSubTypeOptions(category, type) {
     } else {
       return [];
     }
+  } else if (type === "Azure") {
+    if (category === "Model") {
+      return (
+        openaiModels
+      );
+    }
   } else {
     return [];
   }
 }

+export function getProviderAzureApiVersionOptions() {
+  return ([
+    {id: "", name: ""},
+    {id: "2023-03-15-preview", name: "2023-03-15-preview"},
+    {id: "2023-05-15", name: "2023-05-15"},
+    {id: "2023-06-01-preview", name: "2023-06-01-preview"},
+    {id: "2023-07-01-preview", name: "2023-07-01-preview"},
+    {id: "2023-08-01-preview", name: "2023-08-01-preview"},
+  ]);
+}