* feat: support Private model chat * fix: rename Private model to Local model * fix: rename Private model to Local modelHEAD
@@ -0,0 +1,119 @@ | |||||
// Copyright 2023 The casbin Authors. All Rights Reserved. | |||||
// | |||||
// Licensed under the Apache License, Version 2.0 (the "License"); | |||||
// you may not use this file except in compliance with the License. | |||||
// You may obtain a copy of the License at | |||||
// | |||||
// http://www.apache.org/licenses/LICENSE-2.0 | |||||
// | |||||
// Unless required by applicable law or agreed to in writing, software | |||||
// distributed under the License is distributed on an "AS IS" BASIS, | |||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
// See the License for the specific language governing permissions and | |||||
// limitations under the License. | |||||
package model | |||||
import ( | |||||
"context" | |||||
"fmt" | |||||
"io" | |||||
"net/http" | |||||
"strings" | |||||
"github.com/sashabaranov/go-openai" | |||||
) | |||||
type LocalModelProvider struct { | |||||
subType string | |||||
secretKey string | |||||
temperature float32 | |||||
topP float32 | |||||
frequencyPenalty float32 | |||||
presencePenalty float32 | |||||
providerUrl string | |||||
} | |||||
func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { | |||||
p := &LocalModelProvider{ | |||||
subType: subType, | |||||
secretKey: secretKey, | |||||
temperature: temperature, | |||||
topP: topP, | |||||
frequencyPenalty: frequencyPenalty, | |||||
presencePenalty: presencePenalty, | |||||
providerUrl: providerUrl, | |||||
} | |||||
return p, nil | |||||
} | |||||
func getLocalClientFromUrl(authToken string, url string) *openai.Client { | |||||
config := openai.DefaultConfig(authToken) | |||||
config.BaseURL = url | |||||
c := openai.NewClientWithConfig(config) | |||||
return c | |||||
} | |||||
func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { | |||||
client := getLocalClientFromUrl(p.secretKey, p.providerUrl) | |||||
ctx := context.Background() | |||||
flusher, ok := writer.(http.Flusher) | |||||
if !ok { | |||||
return fmt.Errorf("writer does not implement http.Flusher") | |||||
} | |||||
model := p.subType | |||||
temperature := p.temperature | |||||
topP := p.topP | |||||
frequencyPenalty := p.frequencyPenalty | |||||
presencePenalty := p.presencePenalty | |||||
respStream, err := client.CreateCompletionStream( | |||||
ctx, | |||||
openai.CompletionRequest{ | |||||
Model: model, | |||||
Prompt: question, | |||||
Stream: true, | |||||
Temperature: temperature, | |||||
TopP: topP, | |||||
FrequencyPenalty: frequencyPenalty, | |||||
PresencePenalty: presencePenalty, | |||||
}, | |||||
) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
defer respStream.Close() | |||||
isLeadingReturn := true | |||||
for { | |||||
completion, streamErr := respStream.Recv() | |||||
if streamErr != nil { | |||||
if streamErr == io.EOF { | |||||
break | |||||
} | |||||
return streamErr | |||||
} | |||||
data := completion.Choices[0].Text | |||||
if isLeadingReturn && len(data) != 0 { | |||||
if strings.Count(data, "\n") == len(data) { | |||||
continue | |||||
} else { | |||||
isLeadingReturn = false | |||||
} | |||||
} | |||||
// Write the streamed data as Server-Sent Events | |||||
if _, err = fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { | |||||
return err | |||||
} | |||||
flusher.Flush() | |||||
// Append the response to the strings.Builder | |||||
builder.WriteString(data) | |||||
} | |||||
return nil | |||||
} |
@@ -23,7 +23,7 @@ type ModelProvider interface { | |||||
QueryText(question string, writer io.Writer, builder *strings.Builder) error | QueryText(question string, writer io.Writer, builder *strings.Builder) error | ||||
} | } | ||||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { | |||||
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) { | |||||
var p ModelProvider | var p ModelProvider | ||||
var err error | var err error | ||||
if typ == "OpenAI" { | if typ == "OpenAI" { | ||||
@@ -40,6 +40,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret | |||||
p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) | ||||
} else if typ == "ChatGLM" { | } else if typ == "ChatGLM" { | ||||
p, err = NewChatGLMModelProvider(subType, clientSecret) | p, err = NewChatGLMModelProvider(subType, clientSecret) | ||||
} else if typ == "Local" { | |||||
p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) | |||||
} | } | ||||
if err != nil { | if err != nil { | ||||
@@ -211,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { | |||||
} | } | ||||
func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | ||||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) | |||||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -646,6 +646,7 @@ export function getProviderTypeOptions(category) { | |||||
{id: "Ernie", name: "Ernie"}, | {id: "Ernie", name: "Ernie"}, | ||||
{id: "iFlytek", name: "iFlytek"}, | {id: "iFlytek", name: "iFlytek"}, | ||||
{id: "ChatGLM", name: "ChatGLM"}, | {id: "ChatGLM", name: "ChatGLM"}, | ||||
{id: "Local", name: "Local"}, | |||||
] | ] | ||||
); | ); | ||||
} else if (category === "Embedding") { | } else if (category === "Embedding") { | ||||
@@ -820,6 +821,22 @@ export function getProviderSubTypeOptions(category, type) { | |||||
{id: "chatglm2-6b", name: "chatglm2-6b"}, | {id: "chatglm2-6b", name: "chatglm2-6b"}, | ||||
] | ] | ||||
); | ); | ||||
} else if (type === "Local") { | |||||
if (category === "Model") { | |||||
return ( | |||||
[ | |||||
{id: "custom-model", name: "custom-model"}, | |||||
] | |||||
); | |||||
} else if (category === "Embedding") { | |||||
return ( | |||||
[ | |||||
{id: "custom-embedding", name: "custom-embedding"}, | |||||
] | |||||
); | |||||
} else { | |||||
return []; | |||||
} | |||||
} else { | } else { | ||||
return []; | return []; | ||||
} | } | ||||