You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

local.go 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package embedding
import (
	"context"
	"errors"

	"github.com/sashabaranov/go-openai"
)
  19. type LocalEmbeddingProvider struct {
  20. subType string
  21. secretKey string
  22. providerUrl string
  23. }
  24. func NewLocalEmbeddingProvider(subType string, secretKey string, providerUrl string) (*LocalEmbeddingProvider, error) {
  25. p := &LocalEmbeddingProvider{
  26. subType: subType,
  27. secretKey: secretKey,
  28. providerUrl: providerUrl,
  29. }
  30. return p, nil
  31. }
  32. func getLocalClientFromUrl(authToken string, url string) *openai.Client {
  33. config := openai.DefaultConfig(authToken)
  34. config.BaseURL = url
  35. c := openai.NewClientWithConfig(config)
  36. return c
  37. }
  38. func (p *LocalEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, error) {
  39. client := getLocalClientFromUrl(p.secretKey, p.providerUrl)
  40. resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
  41. Input: []string{text},
  42. Model: openai.EmbeddingModel(1),
  43. })
  44. if err != nil {
  45. return nil, err
  46. }
  47. return resp.Data[0].Embedding, nil
  48. }