From 59acb72a5990e55a5f003da5fe9e67ca3a2a09e5 Mon Sep 17 00:00:00 2001 From: Kelvin Chiu Date: Thu, 14 Sep 2023 15:04:10 +0800 Subject: [PATCH] feat: support iFlytek spark model (#635) --- go.mod | 4 +++ go.sum | 9 ++++++ model/iflytek.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++ model/provider.go | 2 ++ web/src/Setting.js | 8 +++++ 5 files changed, 102 insertions(+) create mode 100644 model/iflytek.go diff --git a/go.mod b/go.mod index 6bb96a5..a4b3960 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/go-cmp v0.5.9 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/henomis/cohere-go v1.0.1 // indirect github.com/henomis/restclientgo v1.0.5 // indirect @@ -64,6 +65,9 @@ require ( github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect + github.com/vogo/gorun v1.1.0 // indirect + github.com/vogo/logger v1.5.1 // indirect + github.com/vogo/xfspark v0.1.2 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/mod v0.8.0 // indirect diff --git a/go.sum b/go.sum index b0c225c..1245a79 100644 --- a/go.sum +++ b/go.sum @@ -253,6 +253,8 @@ github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51 github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= @@ -579,6 +581,12 @@ github.com/unidoc/unioffice v1.4.0 h1:yl+TbZJu2GTVYAYvu51wppj0R+fPC67xzVcy91qgrz github.com/unidoc/unioffice v1.4.0/go.mod h1:7wl8btOkZW1TfqfpDWoujRXkUpowwisGRYDo7COHBiI= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/vogo/gorun v1.1.0 h1:i/8HhmNyjMa79bFn1ZHGTO3KyiWQcZPBsFGpvgkcap8= +github.com/vogo/gorun v1.1.0/go.mod h1:MyOjF/DbZSz40GQ5657Ou9NEfNlUbMhvpXjDvh4L8OY= +github.com/vogo/logger v1.5.1 h1:voyVY69TpM/x1lml4LYy4jMe5z0kDh5jW1oatcikajM= +github.com/vogo/logger v1.5.1/go.mod h1:9U+qupncHpWpt4ptlxFm9bvDb9EjbGIA+cY/tYtW4Kg= +github.com/vogo/xfspark v0.1.2 h1:QdW7jvFL6bgHQ74xTuAlP2IBj4sIVCpBE/LS0LW4ePE= +github.com/vogo/xfspark v0.1.2/go.mod h1:KHvAxw98fpLsFdmFYm9ps99eGaLbuLBTBa9rVG6SFyo= github.com/wcharczuk/go-chart/v2 v2.1.0/go.mod h1:yx7MvAVNcP/kN9lKXM/NTce4au4DFN99j6i1OwDclNA= github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= @@ -1037,6 +1045,7 @@ modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= diff --git a/model/iflytek.go b/model/iflytek.go new file mode 100644 index 0000000..2783565 --- /dev/null +++ b/model/iflytek.go @@ -0,0 +1,79 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "fmt" + "io" + "net/http" + "strings" + + iflytek "github.com/vogo/xfspark/chat" +) + +type iFlytekModelProvider struct { + subType string + appID string + apiKey string + secretKey string +} + +func NewiFlytekModelProvider(subType string, secretKey string) (*iFlytekModelProvider, error) { + p := &iFlytekModelProvider{ + subType: subType, + appID: "", + apiKey: "", + secretKey: secretKey, + } + return p, nil +} + +func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { + client := iflytek.NewServer(p.appID, p.apiKey, p.secretKey) + flusher, ok := writer.(http.Flusher) + if !ok { + return fmt.Errorf("writer does not implement http.Flusher") + } + + session, err := client.GetSession("1") + if err != nil { + return fmt.Errorf("iflytek get session error: %v", err) + } + if session == nil { + return fmt.Errorf("iflytek get session error: session is nil") + } + + response, err := session.Send(question) + if err != nil { + return fmt.Errorf("iflytek send error: %v", err) + } + + flushData := func(data string) error { + if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { + return err + } + flusher.Flush() + builder.WriteString(data) + return nil + } + + err = flushData(response) + + if builder != nil { + builder.WriteString(response) + } + + return nil +} diff --git a/model/provider.go b/model/provider.go index c5a3615..b27eeb4 100644 --- a/model/provider.go +++ b/model/provider.go @@ -34,6 +34,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret p, err = NewOpenRouterModelProvider(subType, clientSecret) } else if typ == "Ernie" { p, err = NewErnieModelProvider(subType, clientId, clientSecret) + } else if typ == "iFlytek" { + p, err = NewiFlytekModelProvider(subType, clientSecret) } if err != nil { diff --git a/web/src/Setting.js b/web/src/Setting.js index 94e9421..ba006a9 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -632,6 +632,7 @@ export function getProviderTypeOptions(category) { {id: "Hugging Face", name: "Hugging Face"}, {id: "OpenRouter", name: "OpenRouter"}, {id: "Ernie", name: "Ernie"}, + {id: "iFlytek", name: "iFlytek"}, ] ); } else if (category === "Embedding") { @@ -776,6 +777,13 @@ export function getProviderSubTypeOptions(category, type) { {id: "embed-multilingual-v2.0", name: "embed-multilingual-v2.0"}, ] ); + } else if (type === "iFlytek") { + return ( + [ + {id: "spark-v1.5", name: "spark-v1.5"}, + {id: "spark-v2.0", name: "spark-v2.0"}, + ] + ); } else { return []; }