* feat: support Azure model provider * fix: convert apiVersion to stringmaster
| @@ -0,0 +1,55 @@ | |||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| package model | |||
| import ( | |||
| "github.com/casbin/casibase/proxy" | |||
| "github.com/sashabaranov/go-openai" | |||
| ) | |||
| func NewAzureModelProvider(typ string, subType string, deploymentName string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (*LocalModelProvider, error) { | |||
| p := &LocalModelProvider{ | |||
| typ: typ, | |||
| subType: subType, | |||
| deploymentName: deploymentName, | |||
| secretKey: secretKey, | |||
| temperature: temperature, | |||
| topP: topP, | |||
| frequencyPenalty: frequencyPenalty, | |||
| presencePenalty: presencePenalty, | |||
| providerUrl: providerUrl, | |||
| apiVersion: apiVersion, | |||
| } | |||
| return p, nil | |||
| } | |||
| func getAzureClientFromToken(subtype string, deploymentName string, authToken string, url string, apiVersion string) *openai.Client { | |||
| config := openai.DefaultAzureConfig(authToken, url) | |||
| config.HTTPClient = proxy.ProxyHttpClient | |||
| if apiVersion != "" { | |||
| config.APIVersion = apiVersion | |||
| } | |||
| if deploymentName != "" { | |||
| config.AzureModelMapperFunc = func(model string) string { | |||
| azureModelMapping := map[string]string{ | |||
| subtype: deploymentName, | |||
| } | |||
| return azureModelMapping[model] | |||
| } | |||
| } | |||
| c := openai.NewClientWithConfig(config) | |||
| return c | |||
| } | |||
| @@ -25,17 +25,21 @@ import ( | |||
| ) | |||
| type LocalModelProvider struct { | |||
| typ string | |||
| subType string | |||
| deploymentName string | |||
| secretKey string | |||
| temperature float32 | |||
| topP float32 | |||
| frequencyPenalty float32 | |||
| presencePenalty float32 | |||
| providerUrl string | |||
| apiVersion string | |||
| } | |||
| func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { | |||
| func NewLocalModelProvider(typ string, subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { | |||
| p := &LocalModelProvider{ | |||
| typ: typ, | |||
| subType: subType, | |||
| secretKey: secretKey, | |||
| temperature: temperature, | |||
| @@ -56,7 +60,12 @@ func getLocalClientFromUrl(authToken string, url string) *openai.Client { | |||
| } | |||
| func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | |||
| client := getLocalClientFromUrl(p.secretKey, p.providerUrl) | |||
| var client *openai.Client | |||
| if p.typ == "Local" { | |||
| client = getLocalClientFromUrl(p.secretKey, p.providerUrl) | |||
| } else if p.typ == "Azure" { | |||
| client = getAzureClientFromToken(p.subType, p.deploymentName, p.secretKey, p.providerUrl, p.apiVersion) | |||
| } | |||
| ctx := context.Background() | |||
| flusher, ok := writer.(http.Flusher) | |||
| @@ -70,11 +79,16 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde | |||
| frequencyPenalty := p.frequencyPenalty | |||
| presencePenalty := p.presencePenalty | |||
| respStream, err := client.CreateCompletionStream( | |||
| respStream, err := client.CreateChatCompletionStream( | |||
| ctx, | |||
| openai.CompletionRequest{ | |||
| Model: model, | |||
| Prompt: question, | |||
| openai.ChatCompletionRequest{ | |||
| Model: model, | |||
| Messages: []openai.ChatCompletionMessage{ | |||
| { | |||
| Role: openai.ChatMessageRoleUser, | |||
| Content: question, | |||
| }, | |||
| }, | |||
| Stream: true, | |||
| Temperature: temperature, | |||
| TopP: topP, | |||
| @@ -97,7 +111,7 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builde | |||
| return streamErr | |||
| } | |||
| data := completion.Choices[0].Text | |||
| data := completion.Choices[0].Delta.Content | |||
| if isLeadingReturn && len(data) != 0 { | |||
| if strings.Count(data, "\n") == len(data) { | |||
| continue | |||
| @@ -23,7 +23,7 @@ type ModelProvider interface { | |||
| QueryText(question string, writer io.Writer, builder *strings.Builder) error | |||
| } | |||
| func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) { | |||
| func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string, apiVersion string) (ModelProvider, error) { | |||
| var p ModelProvider | |||
| var err error | |||
| if typ == "OpenAI" { | |||
| @@ -43,7 +43,9 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||
| } else if typ == "MiniMax" { | |||
| p, err = NewMiniMaxModelProvider(subType, clientId, clientSecret, temperature) | |||
| } else if typ == "Local" { | |||
| p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) | |||
| p, err = NewLocalModelProvider(typ, subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) | |||
| } else if typ == "Azure" { | |||
| p, err = NewAzureModelProvider(typ, subType, clientId, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl, apiVersion) | |||
| } | |||
| if err != nil { | |||
| @@ -36,6 +36,7 @@ type Provider struct { | |||
| ClientId string `xorm:"varchar(100)" json:"clientId"` | |||
| ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"` | |||
| ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"` | |||
| ApiVersion string `xorm:"varchar(100)" json:"apiVersion"` | |||
| Temperature float32 `xorm:"float" json:"temperature"` | |||
| TopP float32 `xorm:"float" json:"topP"` | |||
| @@ -211,7 +212,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { | |||
| } | |||
| func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||
| pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl) | |||
| pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl, p.ApiVersion) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| @@ -105,6 +105,14 @@ class ProviderEditPage extends React.Component { | |||
| ); | |||
| } | |||
| handleTagChange = (key, value) => { | |||
| if (Array.isArray(value) && value.length > 0) { | |||
| this.updateProviderField(key, value[value.length - 1]); | |||
| } else { | |||
| this.updateProviderField(key, value); | |||
| } | |||
| }; | |||
| renderProvider() { | |||
| return ( | |||
| <Card size="small" title={ | |||
| @@ -186,6 +194,8 @@ class ProviderEditPage extends React.Component { | |||
| this.updateProviderField("subType", "chatglm2-6b"); | |||
| } else if (value === "Local") { | |||
| this.updateProviderField("subType", "custom-model"); | |||
| } else if (value === "Azure") { | |||
| this.updateProviderField("subType", "gpt-4"); | |||
| } | |||
| } else if (this.state.provider.category === "Embedding") { | |||
| if (value === "OpenAI") { | |||
| @@ -480,6 +490,35 @@ class ProviderEditPage extends React.Component { | |||
| </> | |||
| ) : null | |||
| } | |||
| { | |||
| ((this.state.provider.category === "Model") && this.state.provider.type === "Azure") ? ( | |||
| <> | |||
| <Row style={{marginTop: "20px"}}> | |||
| <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
| {i18next.t("provider:Deployment Name")}: | |||
| </Col> | |||
| <Col span={22} > | |||
| <Input value={this.state.provider.clientId} onChange={e => { | |||
| this.updateProviderField("clientId", e.target.value); | |||
| }} /> | |||
| </Col> | |||
| </Row> | |||
| <Row style={{marginTop: "20px"}}> | |||
| <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
| {i18next.t("provider:API Version")}: | |||
| </Col> | |||
| <Col span={22} > | |||
| <Select virtual={false} mode="tags" style={{width: "100%"}} | |||
| value={this.state.provider.apiVersion} | |||
| onSelect={(value) => {this.handleTagChange("apiVersion", value);}} | |||
| onChange={(value) => {this.handleTagChange("apiVersion", value);}} | |||
| options={Setting.getProviderAzureApiVersionOptions().map((item) => Setting.getOption(item.name, item.id))} | |||
| /> | |||
| </Col> | |||
| </Row> | |||
| </> | |||
| ) : null | |||
| } | |||
| <Row style={{marginTop: "20px"}} > | |||
| <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> | |||
| {i18next.t("general:Provider URL")}: | |||
| @@ -64,6 +64,7 @@ class ProviderListPage extends React.Component { | |||
| frequencyPenalty: 0, | |||
| presencePenalty: 0, | |||
| providerUrl: "https://platform.openai.com/account/api-keys", | |||
| apiVersion: "", | |||
| }; | |||
| } | |||
| @@ -648,6 +648,7 @@ export function getProviderTypeOptions(category) { | |||
| {id: "ChatGLM", name: "ChatGLM"}, | |||
| {id: "MiniMax", name: "MiniMax"}, | |||
| {id: "Local", name: "Local"}, | |||
| {id: "Azure", name: "Azure"}, | |||
| ] | |||
| ); | |||
| } else if (category === "Embedding") { | |||
| @@ -665,57 +666,61 @@ export function getProviderTypeOptions(category) { | |||
| } | |||
| } | |||
| const openaiModels = [ | |||
| {id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||
| {id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||
| {id: "gpt-4-32k", name: "gpt-4-32k"}, | |||
| {id: "gpt-4-0613", name: "gpt-4-0613"}, | |||
| {id: "gpt-4-0314", name: "gpt-4-0314"}, | |||
| {id: "gpt-4", name: "gpt-4"}, | |||
| {id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||
| {id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||
| {id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||
| {id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||
| {id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||
| {id: "text-davinci-003", name: "text-davinci-003"}, | |||
| {id: "text-davinci-002", name: "text-davinci-002"}, | |||
| {id: "text-curie-001", name: "text-curie-001"}, | |||
| {id: "text-babbage-001", name: "text-babbage-001"}, | |||
| {id: "text-ada-001", name: "text-ada-001"}, | |||
| {id: "text-davinci-001", name: "text-davinci-001"}, | |||
| {id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||
| {id: "davinci", name: "davinci"}, | |||
| {id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||
| {id: "curie", name: "curie"}, | |||
| {id: "ada", name: "ada"}, | |||
| {id: "babbage", name: "babbage"}, | |||
| ]; | |||
| const openaiEmbeddings = [ | |||
| {id: "1", name: "AdaSimilarity"}, | |||
| {id: "2", name: "BabbageSimilarity"}, | |||
| {id: "3", name: "CurieSimilarity"}, | |||
| {id: "4", name: "DavinciSimilarity"}, | |||
| {id: "5", name: "AdaSearchDocument"}, | |||
| {id: "6", name: "AdaSearchQuery"}, | |||
| {id: "7", name: "BabbageSearchDocument"}, | |||
| {id: "8", name: "BabbageSearchQuery"}, | |||
| {id: "9", name: "CurieSearchDocument"}, | |||
| {id: "10", name: "CurieSearchQuery"}, | |||
| {id: "11", name: "DavinciSearchDocument"}, | |||
| {id: "12", name: "DavinciSearchQuery"}, | |||
| {id: "13", name: "AdaCodeSearchCode"}, | |||
| {id: "14", name: "AdaCodeSearchText"}, | |||
| {id: "15", name: "BabbageCodeSearchCode"}, | |||
| {id: "16", name: "BabbageCodeSearchText"}, | |||
| {id: "17", name: "AdaEmbeddingV2"}, | |||
| ]; | |||
| export function getProviderSubTypeOptions(category, type) { | |||
| if (type === "OpenAI") { | |||
| if (category === "Model") { | |||
| return ( | |||
| [ | |||
| {id: "gpt-4-32k-0613", name: "gpt-4-32k-0613"}, | |||
| {id: "gpt-4-32k-0314", name: "gpt-4-32k-0314"}, | |||
| {id: "gpt-4-32k", name: "gpt-4-32k"}, | |||
| {id: "gpt-4-0613", name: "gpt-4-0613"}, | |||
| {id: "gpt-4-0314", name: "gpt-4-0314"}, | |||
| {id: "gpt-4", name: "gpt-4"}, | |||
| {id: "gpt-3.5-turbo-0613", name: "gpt-3.5-turbo-0613"}, | |||
| {id: "gpt-3.5-turbo-0301", name: "gpt-3.5-turbo-0301"}, | |||
| {id: "gpt-3.5-turbo-16k", name: "gpt-3.5-turbo-16k"}, | |||
| {id: "gpt-3.5-turbo-16k-0613", name: "gpt-3.5-turbo-16k-0613"}, | |||
| {id: "gpt-3.5-turbo", name: "gpt-3.5-turbo"}, | |||
| {id: "text-davinci-003", name: "text-davinci-003"}, | |||
| {id: "text-davinci-002", name: "text-davinci-002"}, | |||
| {id: "text-curie-001", name: "text-curie-001"}, | |||
| {id: "text-babbage-001", name: "text-babbage-001"}, | |||
| {id: "text-ada-001", name: "text-ada-001"}, | |||
| {id: "text-davinci-001", name: "text-davinci-001"}, | |||
| {id: "davinci-instruct-beta", name: "davinci-instruct-beta"}, | |||
| {id: "davinci", name: "davinci"}, | |||
| {id: "curie-instruct-beta", name: "curie-instruct-beta"}, | |||
| {id: "curie", name: "curie"}, | |||
| {id: "ada", name: "ada"}, | |||
| {id: "babbage", name: "babbage"}, | |||
| ] | |||
| openaiModels | |||
| ); | |||
| } else if (category === "Embedding") { | |||
| return ( | |||
| [ | |||
| {id: "1", name: "AdaSimilarity"}, | |||
| {id: "2", name: "BabbageSimilarity"}, | |||
| {id: "3", name: "CurieSimilarity"}, | |||
| {id: "4", name: "DavinciSimilarity"}, | |||
| {id: "5", name: "AdaSearchDocument"}, | |||
| {id: "6", name: "AdaSearchQuery"}, | |||
| {id: "7", name: "BabbageSearchDocument"}, | |||
| {id: "8", name: "BabbageSearchQuery"}, | |||
| {id: "9", name: "CurieSearchDocument"}, | |||
| {id: "10", name: "CurieSearchQuery"}, | |||
| {id: "11", name: "DavinciSearchDocument"}, | |||
| {id: "12", name: "DavinciSearchQuery"}, | |||
| {id: "13", name: "AdaCodeSearchCode"}, | |||
| {id: "14", name: "AdaCodeSearchText"}, | |||
| {id: "15", name: "BabbageCodeSearchCode"}, | |||
| {id: "16", name: "BabbageCodeSearchText"}, | |||
| {id: "17", name: "AdaEmbeddingV2"}, | |||
| ] | |||
| openaiEmbeddings | |||
| ); | |||
| } else { | |||
| return []; | |||
| @@ -845,7 +850,24 @@ export function getProviderSubTypeOptions(category, type) { | |||
| } else { | |||
| return []; | |||
| } | |||
| } else if (type === "Azure") { | |||
| if (category === "Model") { | |||
| return ( | |||
| openaiModels | |||
| ); | |||
| } | |||
| } else { | |||
| return []; | |||
| } | |||
| } | |||
| export function getProviderAzureApiVersionOptions() { | |||
| return ([ | |||
| {id: "", name: ""}, | |||
| {id: "2023-03-15-preview", name: "2023-03-15-preview"}, | |||
| {id: "2023-05-15", name: "2023-05-15"}, | |||
| {id: "2023-06-01-preview", name: "2023-06-01-preview"}, | |||
| {id: "2023-07-01-preview", name: "2023-07-01-preview"}, | |||
| {id: "2023-08-01-preview", name: "2023-08-01-preview"}, | |||
| ]); | |||
| } | |||