From d2bd6cb3ab0c9887163c2ce3d6eb263e39ddb4f3 Mon Sep 17 00:00:00 2001 From: Kelvin Chiu Date: Mon, 25 Sep 2023 08:29:42 +0800 Subject: [PATCH] feat: support Local model chat using openai format (#651) * feat: support Private model chat * fix: rename Private model to Local model * fix: rename Private model to Local model --- model/local.go | 119 +++++++++++++++++++++++++++++++++++++++++++++ model/provider.go | 4 +- object/provider.go | 2 +- web/src/Setting.js | 17 +++++++ 4 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 model/local.go diff --git a/model/local.go b/model/local.go new file mode 100644 index 0000000..c882735 --- /dev/null +++ b/model/local.go @@ -0,0 +1,119 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + "github.com/sashabaranov/go-openai" +) + +type LocalModelProvider struct { + subType string + secretKey string + temperature float32 + topP float32 + frequencyPenalty float32 + presencePenalty float32 + providerUrl string +} + +func NewLocalModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32, providerUrl string) (*LocalModelProvider, error) { + p := &LocalModelProvider{ + subType: subType, + secretKey: secretKey, + temperature: temperature, + topP: topP, + frequencyPenalty: frequencyPenalty, + presencePenalty: presencePenalty, + providerUrl: providerUrl, + } + return p, nil +} + +func getLocalClientFromUrl(authToken string, url string) *openai.Client { + config := openai.DefaultConfig(authToken) + config.BaseURL = url + + c := openai.NewClientWithConfig(config) + return c +} + +func (p *LocalModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { + client := getLocalClientFromUrl(p.secretKey, p.providerUrl) + + ctx := context.Background() + flusher, ok := writer.(http.Flusher) + if !ok { + return fmt.Errorf("writer does not implement http.Flusher") + } + + model := p.subType + temperature := p.temperature + topP := p.topP + frequencyPenalty := p.frequencyPenalty + presencePenalty := p.presencePenalty + + respStream, err := client.CreateCompletionStream( + ctx, + openai.CompletionRequest{ + Model: model, + Prompt: question, + Stream: true, + Temperature: temperature, + TopP: topP, + FrequencyPenalty: frequencyPenalty, + PresencePenalty: presencePenalty, + }, + ) + if err != nil { + return err + } + defer respStream.Close() + + isLeadingReturn := true + for { + completion, streamErr := respStream.Recv() + if streamErr != nil { + if streamErr == io.EOF { + break + } + return streamErr + } + + data := completion.Choices[0].Text + if isLeadingReturn && len(data) != 0 { + if strings.Count(data, "\n") == len(data) { + continue + } else { + isLeadingReturn = false + } + } + + // Write the streamed data as Server-Sent Events + if _, err = fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { + return err + } + flusher.Flush() + // Append the response to the strings.Builder + builder.WriteString(data) + } + + return nil +} diff --git a/model/provider.go b/model/provider.go index c7d83f2..39e5674 100644 --- a/model/provider.go +++ b/model/provider.go @@ -23,7 +23,7 @@ type ModelProvider interface { QueryText(question string, writer io.Writer, builder *strings.Builder) error } -func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) { +func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, topK int, frequencyPenalty float32, presencePenalty float32, providerUrl string) (ModelProvider, error) { var p ModelProvider var err error if typ == "OpenAI" { @@ -40,6 +40,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret p, err = NewiFlytekModelProvider(subType, clientSecret, temperature, topK) } else if typ == "ChatGLM" { p, err = NewChatGLMModelProvider(subType, clientSecret) + } else if typ == "Local" { + p, err = NewLocalModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty, providerUrl) } if err != nil { diff --git a/object/provider.go b/object/provider.go index 8cfb334..2b0aa9b 100644 --- a/object/provider.go +++ b/object/provider.go @@ -211,7 +211,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) { } func (p *Provider) GetModelProvider() (model.ModelProvider, error) { - pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty) + pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.TopK, p.FrequencyPenalty, p.PresencePenalty, p.ProviderUrl) if err != nil { return nil, err } diff --git a/web/src/Setting.js b/web/src/Setting.js index 397a141..b312be8 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -646,6 +646,7 @@ export function getProviderTypeOptions(category) { {id: "Ernie", name: "Ernie"}, {id: "iFlytek", name: "iFlytek"}, {id: "ChatGLM", name: "ChatGLM"}, + {id: "Local", name: "Local"}, ] ); } else if (category === "Embedding") { @@ -820,6 +821,22 @@ export function getProviderSubTypeOptions(category, type) { {id: "chatglm2-6b", name: "chatglm2-6b"}, ] ); + } else if (type === "Local") { + if (category === "Model") { + return ( + [ + {id: "custom-model", name: "custom-model"}, + ] + ); + } else if (category === "Embedding") { + return ( + [ + {id: "custom-embedding", name: "custom-embedding"}, + ] + ); + } else { + return []; + } } else { return []; }