diff --git a/go.mod b/go.mod index a4b3960..6693171 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/pretty v0.3.0 // indirect + github.com/madebywelch/anthropic-go v1.0.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index 1245a79..09d40af 100644 --- a/go.sum +++ b/go.sum @@ -388,6 +388,8 @@ github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-b github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/madebywelch/anthropic-go v1.0.1 h1:LalIkikXbN53MIHGQekhwDWTs/x+/WHU71Ht1T4F0Ug= +github.com/madebywelch/anthropic-go v1.0.1/go.mod h1:ipU4SV1KHLcxo7lpR/O3JWq6985kaNcdYQgB46gaQCQ= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= diff --git a/model/claude.go b/model/claude.go new file mode 100644 index 0000000..e5cfb95 --- /dev/null +++ b/model/claude.go @@ -0,0 +1,63 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "fmt" + "io" + "net/http" + "strings" + + "github.com/madebywelch/anthropic-go/pkg/anthropic" +) + +type ClaudeModelProvider struct { + subType string + secretKey string +} + +func NewClaudeModelProvider(subType string, secretKey string) (*ClaudeModelProvider, error) { + return &ClaudeModelProvider{subType: subType, secretKey: secretKey}, nil +} + +func (p *ClaudeModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error { + client, err := anthropic.NewClient(p.secretKey) + if err != nil { + panic(err) + } + response, _ := client.Complete(&anthropic.CompletionRequest{ + Prompt: anthropic.GetPrompt(question), + Model: anthropic.Model(p.subType), + MaxTokensToSample: 100, + StopSequences: []string{"\r", "Human:"}, + }, nil) + flusher, ok := writer.(http.Flusher) + if !ok { + return fmt.Errorf("writer does not implement http.Flusher") + } + flushData := func(data string) error { + if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil { + return err + } + flusher.Flush() + builder.WriteString(data) + return nil + } + err = flushData(response.Completion) + if err != nil { + return err + } + return nil +} diff --git a/model/provider.go b/model/provider.go index b27eeb4..4a9fde4 100644 --- a/model/provider.go +++ b/model/provider.go @@ -30,6 +30,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret p, err = NewOpenAiModelProvider(subType, clientSecret) } else if typ == "Hugging Face" { p, err = NewHuggingFaceModelProvider(subType, clientSecret) + } else if typ == "Claude" { + p, err = NewClaudeModelProvider(subType, clientSecret) } else if typ == "OpenRouter" { p, err = NewOpenRouterModelProvider(subType, clientSecret) } else if typ == "Ernie" { diff --git a/web/src/Setting.js b/web/src/Setting.js index 0bf1815..01fea4f 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -636,6 +636,7 @@ export function getProviderTypeOptions(category) { [ {id: "OpenAI", name: "OpenAI"}, {id: "Hugging Face", name: "Hugging Face"}, + {id: "Claude", name: "Claude"}, {id: "OpenRouter", name: "OpenRouter"}, {id: "Ernie", name: "Ernie"}, {id: "iFlytek", name: "iFlytek"}, @@ -731,6 +732,23 @@ export function getProviderSubTypeOptions(category, type) { } else { return []; } + } else if (type === "Claude") { + return ( + [ + {id: "claude-2", name: "claude-2"}, + {id: "claude-v1", name: "claude-v1"}, + {id: "claude-v1-100k", name: "claude-v1-100k"}, + {id: "claude-instant-v1", name: "claude-instant-v1"}, + {id: "claude-instant-v1-100k", name: "claude-instant-v1-100k"}, + {id: "claude-v1.3", name: "claude-v1.3"}, + {id: "claude-v1.3-100k", name: "claude-v1.3-100k"}, + {id: "claude-v1.2", name: "claude-v1.2"}, + {id: "claude-v1.0", name: "claude-v1.0"}, + {id: "claude-instant-v1.1", name: "claude-instant-v1.1"}, + {id: "claude-instant-v1.1-100k", name: "claude-instant-v1.1-100k"}, + {id: "claude-instant-v1.0", name: "claude-instant-v1.0"}, + ] + ); } else if (type === "OpenRouter") { return ( [