* feat: support Private model chat * fix: rename Private model to Local model * fix: rename Private model to Local modelHEAD
| @@ -0,0 +1,119 @@ | |||||
| // Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| package model | |||||
| import ( | |||||
| "context" | |||||
| "fmt" | |||||
| "io" | |||||
| "net/http" | |||||
| "strings" | |||||
| "github.com/sashabaranov/go-openai" | |||||
| ) | |||||
| type LocalModelProvider struct { | |||||
| subType string | |||||
| secretKey string | |||||
| temperature float32 | |||||
| topP float32 | |||||
| frequencyPenalty float32 | |||||
| presencePenalty float32 | |||||
| providerUrl string | |||||
| } | |||||
| func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { | |||||
| p := &LocalModelProvider{ | |||||
| subType: subType, | |||||
| secretKey: secretKey, | |||||
| temperature: temperature, | |||||
| topP: topP, | |||||
| frequencyPenalty: frequencyPenalty, | |||||
| presencePenalty: presencePenalty, | |||||
| providerUrl: providerUrl, | |||||
| } | |||||
| return p, nil | |||||
| } | |||||
| func getLocalClientFromUrl(authToken string, url string) *openai.Client { | |||||
| config := openai.DefaultConfig(authToken) | |||||
| config.BaseURL = url | |||||
| c := openai.NewClientWithConfig(config) | |||||
| return c | |||||
| } | |||||
| func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | |||||
| client := getLocalClientFromUrl(p.secretKey, p.providerUrl) | |||||
| ctx := context.Background() | |||||
| flusher, ok := writer.(http.Flusher) | |||||
| if !ok { | |||||
| return fmt.Errorf("writer does not implement http.Flusher") | |||||
| } | |||||
| model := p.subType | |||||
| temperature := p.temperature | |||||
| topP := p.topP | |||||
| frequencyPenalty := p.frequencyPenalty | |||||
| presencePenalty := p.presencePenalty | |||||
| respStream, err := client.CreateCompletionStream( | |||||
| ctx, | |||||
| openai.CompletionRequest{ | |||||
| Model: model, | |||||
| Prompt: question, | |||||
| Stream: true, | |||||
| Temperature: temperature, | |||||
| TopP: topP, | |||||
| FrequencyPenalty: frequencyPenalty, | |||||
| PresencePenalty: presencePenalty, | |||||
| }, | |||||
| ) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| defer respStream.Close() | |||||
| isLeadingReturn := true | |||||
| for { | |||||
| completion, streamErr := respStream.Recv() | |||||
| if streamErr != nil { | |||||
| if streamErr == io.EOF { | |||||
| break | |||||
| } | |||||
| return streamErr | |||||
| } | |||||
| data := completion.Choices[0].Text | |||||
| if isLeadingReturn && len(data) != 0 { | |||||
| if strings.Count(data, "\n") == len(data) { | |||||
| continue | |||||
| } else { | |||||
| isLeadingReturn = false | |||||
| } | |||||
| } | |||||
| // Write the streamed data as Server-Sent Events | |||||
| if _, err = fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { | |||||
| return err | |||||
| } | |||||
| flusher.Flush() | |||||
| // Append the response to the strings.Builder | |||||
| builder.WriteString(data) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| @@ -23,7 +23,7 @@ type ModelProvider interface { | |||||
| QueryText(question string, writer io.Writer, builder *strings.Builder) error | QueryText(question string, writer io.Writer, builder *strings.Builder) error | ||||
| } | } | ||||
| func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||||
| func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) { | |||||
| var p ModelProvider | var p ModelProvider | ||||
| var err error | var err error | ||||
| if typ == "OpenAI" { | if typ == "OpenAI" { | ||||
| @@ -40,6 +40,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||||
| p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | ||||
| } else if typ == "ChatGLM" { | } else if typ == "ChatGLM" { | ||||
| p, err = NewChatGLMModelProvider(subType, clientSecret) | p, err = NewChatGLMModelProvider(subType, clientSecret) | ||||
| } else if typ == "Local" { | |||||
| p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) | |||||
| } | } | ||||
| if err != nil { | if err != nil { | ||||
| @@ -211,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { | |||||
| } | } | ||||
| func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | ||||
| pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) | |||||
| pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| @@ -646,6 +646,7 @@ export function getProviderTypeOptions(category) { | |||||
| {id: "Ernie", name: "Ernie"}, | {id: "Ernie", name: "Ernie"}, | ||||
| {id: "iFlytek", name: "iFlytek"}, | {id: "iFlytek", name: "iFlytek"}, | ||||
| {id: "ChatGLM", name: "ChatGLM"}, | {id: "ChatGLM", name: "ChatGLM"}, | ||||
| {id: "Local", name: "Local"}, | |||||
| ] | ] | ||||
| ); | ); | ||||
| } else if (category === "Embedding") { | } else if (category === "Embedding") { | ||||
| @@ -820,6 +821,22 @@ export function getProviderSubTypeOptions(category, type) { | |||||
| {id: "chatglm2-6b", name: "chatglm2-6b"}, | {id: "chatglm2-6b", name: "chatglm2-6b"}, | ||||
| ] | ] | ||||
| ); | ); | ||||
| } else if (type === "Local") { | |||||
| if (category === "Model") { | |||||
| return ( | |||||
| [ | |||||
| {id: "custom-model", name: "custom-model"}, | |||||
| ] | |||||
| ); | |||||
| } else if (category === "Embedding") { | |||||
| return ( | |||||
| [ | |||||
| {id: "custom-embedding", name: "custom-embedding"}, | |||||
| ] | |||||
| ); | |||||
| } else { | |||||
| return []; | |||||
| } | |||||
| } else { | } else { | ||||
| return []; | return []; | ||||
| } | } | ||||