From 4fbcafda35990276576583e5fc5afd285df27f73 Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Fri, 29 Sep 2023 18:59:08 +0800 Subject: [PATCH] Improve RefinedWriter code --- controllers/message.go | 6 +-- controllers/message_writer.go | 97 +++++++++++++++++++++++++++++++++++ controllers/util.go | 79 +--------------------------- 3 files changed, 101 insertions(+), 81 deletions(-) create mode 100644 controllers/message_writer.go diff --git a/controllers/message.go b/controllers/message.go index ebcfa07..2c30e3a 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -145,9 +145,9 @@ func (c *ApiController) GetMessageAnswer() { fmt.Printf("Refined Question: [%s]\n", realQuestion) fmt.Printf("Answer: [") - ourWriter := &RefinedWriter{*c.Ctx.ResponseWriter, *NewCleaner(6), []byte{}} + writer := &RefinedWriter{c.Ctx.ResponseWriter, *NewCleaner(6), []byte{}} stringBuilder := &strings.Builder{} - err = modelProviderObj.QueryText(realQuestion, ourWriter, stringBuilder) + err = modelProviderObj.QueryText(realQuestion, writer, stringBuilder) if err != nil { c.ResponseErrorStream(err.Error()) return @@ -162,7 +162,7 @@ func (c *ApiController) GetMessageAnswer() { return } - answer := ourWriter.String() + answer := writer.String() message.Text = answer _, err = object.UpdateMessage(message.GetId(), message) diff --git a/controllers/message_writer.go b/controllers/message_writer.go new file mode 100644 index 0000000..215b8f7 --- /dev/null +++ b/controllers/message_writer.go @@ -0,0 +1,97 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package controllers + +import ( + "fmt" + "net/http" + "strings" +) + +type RefinedWriter struct { + http.ResponseWriter + writerCleaner Cleaner + buf []byte +} + +func newRefinedWriter(w http.ResponseWriter) *RefinedWriter { + return &RefinedWriter{w, *NewCleaner(6), []byte{}} +} + +func (w *RefinedWriter) Write(p []byte) (n int, err error) { + data := strings.TrimRight(strings.TrimLeft(string(p), "event: message\ndata: "), "\n\n") + if w.writerCleaner.cleaned == false && w.writerCleaner.dataTimes < w.writerCleaner.bufferSize { + w.writerCleaner.AddData(data) + if w.writerCleaner.dataTimes == w.writerCleaner.bufferSize { + cleanedData := w.writerCleaner.GetCleanedData() + w.buf = append(w.buf, []byte(cleanedData)...) + return w.ResponseWriter.Write([]byte(fmt.Sprintf("event: message\ndata: %s\n\n", cleanedData))) + } + return 0, nil + } + + w.buf = append(w.buf, []byte(data)...) + return w.ResponseWriter.Write(p) +} + +func (w *RefinedWriter) String() string { + return string(w.buf) +} + +type Cleaner struct { + dataTimes int // Number of times data is added + buffer []string // Buffer of tokens + bufferSize int // Size of the buffer + cleaned bool // Whether the data has been cleaned +} + +func NewCleaner(bufferSize int) *Cleaner { + return &Cleaner{ + dataTimes: 0, + buffer: make([]string, 0, bufferSize), + bufferSize: bufferSize, + cleaned: false, + } +} + +func (c *Cleaner) AddData(data string) { + c.buffer = append(c.buffer, data) + c.dataTimes++ +} + +func (c *Cleaner) GetCleanedData() string { + c.cleaned = true + return cleanString(strings.Join(c.buffer, "")) +} + +func cleanString(data string) string { + data = strings.Replace(data, "?", "", -1) + data = strings.Replace(data, "?", "", -1) + + data = strings.Replace(data, "-", "", -1) + data = strings.Replace(data, "——", "", -1) + + if strings.Contains(data, ":") { + parts := strings.Split(data, ":") + data = parts[len(parts)-1] + } else if strings.Contains(data, ":") { + parts := strings.Split(data, ":") + data = parts[len(parts)-1] + } + + data = strings.TrimSpace(data) + + return data +} diff --git a/controllers/util.go b/controllers/util.go index 0e5e3f9..b6a19b1 100644 --- a/controllers/util.go +++ b/controllers/util.go @@ -14,12 +14,7 @@ package controllers -import ( - "fmt" - "strings" - - "github.com/astaxie/beego/context" -) +import "github.com/astaxie/beego/context" type Response struct { Status string `json:"status"` @@ -92,75 +87,3 @@ func responseError(ctx *context.Context, error string, data ...interface{}) { panic(err) } } - -type RefinedWriter struct { - context.Response - writerCleaner Cleaner - buf []byte -} - -func (w *RefinedWriter) Write(p []byte) (n int, err error) { - data := strings.TrimRight(strings.TrimLeft(string(p), "event: message\ndata: "), "\n\n") - if w.writerCleaner.cleaned == false && w.writerCleaner.dataTimes < w.writerCleaner.bufferSize { - w.writerCleaner.AddData(data) - if w.writerCleaner.dataTimes == w.writerCleaner.bufferSize { - cleanedData := w.writerCleaner.GetCleanedData() - w.buf = append(w.buf, []byte(cleanedData)...) - return w.ResponseWriter.Write([]byte(fmt.Sprintf("event: message\ndata: %s\n\n", cleanedData))) - } - return 0, nil - } - - w.buf = append(w.buf, []byte(data)...) - return w.ResponseWriter.Write(p) -} - -func (w *RefinedWriter) String() string { - return string(w.buf) -} - -type Cleaner struct { - dataTimes int // Number of times data is added - buffer []string // Buffer of tokens - bufferSize int // Size of the buffer - cleaned bool // Whether the data has been cleaned -} - -func NewCleaner(bufferSize int) *Cleaner { - return &Cleaner{ - dataTimes: 0, - buffer: make([]string, 0, bufferSize), - bufferSize: bufferSize, - cleaned: false, - } -} - -func (c *Cleaner) AddData(data string) { - c.buffer = append(c.buffer, data) - c.dataTimes++ -} - -func (c *Cleaner) GetCleanedData() string { - c.cleaned = true - return cleanString(strings.Join(c.buffer, "")) -} - -func cleanString(data string) string { - data = strings.Replace(data, "?", "", -1) - data = strings.Replace(data, "?", "", -1) - - data = strings.Replace(data, "-", "", -1) - data = strings.Replace(data, "——", "", -1) - - if strings.Contains(data, ":") { - parts := strings.Split(data, ":") - data = parts[len(parts)-1] - } else if strings.Contains(data, ":") { - parts := strings.Split(data, ":") - data = parts[len(parts)-1] - } - - data = strings.TrimSpace(data) - - return data -}