diff --git a/ai/ai.go b/ai/ai.go index 381bd29..0178a23 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -139,3 +139,10 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil return nil } + +func GetQuestionWithKnowledge(knowledge string, question string) string { + return fmt.Sprintf(`paragraph: %s + +You are a reading comprehension expert. Please answer the following questions based on the provided content. The content may be in a different language from the questions, so you need to understand the content according to the language of the questions and ensure that your answers are translated into the same language as the questions: +Q1: %s`, knowledge, question) +} diff --git a/ai/embedding.go b/ai/embedding.go index 96f6045..abd101f 100644 --- a/ai/embedding.go +++ b/ai/embedding.go @@ -25,7 +25,7 @@ import ( ) func splitTxt(f io.ReadCloser) []string { - const maxLength = 512 * 3 + const maxLength = 210 * 3 scanner := bufio.NewScanner(f) var res []string var temp string @@ -51,14 +51,14 @@ func GetSplitTxt(f io.ReadCloser) []string { return splitTxt(f) } -func getEmbedding(authToken string, input []string, timeout int) ([]float32, error) { +func getEmbedding(authToken string, text string, timeout int) ([]float32, error) { client := getProxyClientFromToken(authToken) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(30+timeout*2)*time.Second) defer cancel() resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ - Input: input, + Input: []string{text}, Model: openai.AdaEmbeddingV2, }) if err != nil { @@ -68,11 +68,11 @@ func getEmbedding(authToken string, input []string, timeout int) ([]float32, err return resp.Data[0].Embedding, nil } -func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) { +func GetEmbeddingSafe(authToken string, text string) ([]float32, error) { var embedding []float32 var err error for i := 0; i < 10; i++ { - embedding, err = getEmbedding(authToken, input, i) + embedding, err = getEmbedding(authToken, text, i) if err != nil { if i > 0 { fmt.Printf("\tFailed (%d): %s\n", i+1, err.Error()) @@ -88,3 +88,18 @@ func GetEmbeddingSafe(authToken string, input []string) ([]float32, error) { return embedding, nil } } + +func GetNearestVectorIndex(target []float32, vectors [][]float32) int { + targetNorm := norm(target) + + var res int + max := float32(-1.0) + for i, vector := range vectors { + similarity := cosineSimilarity(target, vector, targetNorm) + if similarity > max { + max = similarity + res = i + } + } + return res +} diff --git a/ai/util.go b/ai/util.go index 5cd84cf..6f7eee2 100644 --- a/ai/util.go +++ b/ai/util.go @@ -14,7 +14,11 @@ package ai -import "github.com/pkoukk/tiktoken-go" +import ( + "math" + + "github.com/pkoukk/tiktoken-go" +) func GetTokenSize(model string, prompt string) (int, error) { tkm, err := tiktoken.EncodingForModel(model) @@ -26,3 +30,32 @@ func GetTokenSize(model string, prompt string) (int, error) { res := len(token) return res, nil } + +func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 { + dotProduct := dot(vec1, vec2) + vec2Norm := norm(vec2) + if vec2Norm == 0 { + return 0.0 + } + return dotProduct / (vec1Norm * vec2Norm) +} + +func dot(vec1, vec2 []float32) float32 { + if len(vec1) != len(vec2) { + panic("Vector lengths do not match") + } + + dotProduct := float32(0.0) + for i := range vec1 { + dotProduct += vec1[i] * vec2[i] + } + return dotProduct +} + +func norm(vec []float32) float32 { + normSquared := float32(0.0) + for _, val := range vec { + normSquared += val * val + } + return float32(math.Sqrt(float64(normSquared))) +} diff --git a/controllers/message.go b/controllers/message.go index ce188ac..0993941 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -149,10 +149,19 @@ func (c *ApiController) GetMessageAnswer() { question := questionMessage.Text var stringBuilder strings.Builder - fmt.Printf("Question: [%s]\n", questionMessage.Text) + nearestText, err := object.GetNearestVectorText(authToken, chat.Owner, question) + if err != nil { + c.ResponseErrorStream(err.Error()) + return + } + + realQuestion := ai.GetQuestionWithKnowledge(nearestText, question) + + fmt.Printf("Question: [%s]\n", question) + fmt.Printf("Context: [%s]\n", nearestText) fmt.Printf("Answer: [") - err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) + err = ai.QueryAnswerStream(authToken, realQuestion, c.Ctx.ResponseWriter, &stringBuilder) if err != nil { c.ResponseErrorStream(err.Error()) return diff --git a/object/vector_embedding.go b/object/vector_embedding.go index da675fa..72d964e 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -62,7 +62,7 @@ func getObjectReadCloser(object *storage.Object) (io.ReadCloser, error) { } func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { - embedding, err := ai.GetEmbeddingSafe(authToken, []string{text}) + embedding, err := ai.GetEmbeddingSafe(authToken, text) if err != nil { return false, err } @@ -131,3 +131,38 @@ func setTxtObjectVector(authToken string, provider string, key string, storeName return true, nil } + +func getRelatedVectors(owner string) ([]*Vector, error) { + vectors, err := GetVectors(owner) + if err != nil { + return nil, err + } + if len(vectors) == 0 { + return nil, fmt.Errorf("no knowledge vectors found") + } + + return vectors, nil +} + +func GetNearestVectorText(authToken string, owner string, question string) (string, error) { + qVector, err := ai.GetEmbeddingSafe(authToken, question) + if err != nil { + return "", err + } + if qVector == nil { + return "", fmt.Errorf("no qVector found") + } + + vectors, err := getRelatedVectors(owner) + if err != nil { + return "", err + } + + var nVectors [][]float32 + for _, candidate := range vectors { + nVectors = append(nVectors, candidate.Data) + } + + i := ai.GetNearestVectorIndex(qVector, nVectors) + return vectors[i].Text, nil +}