Browse Source

Rename pkg to model

HEAD
Yang Luo 2 years ago
parent
commit
2ab7e3011b
11 changed files with 18 additions and 19 deletions
  1. +1
    -1
      model/embedding.go
  2. +1
    -1
      model/ernie.go
  3. +1
    -1
      model/huggingface.go
  4. +1
    -1
      model/model.go
  5. +1
    -1
      model/openai.go
  6. +1
    -1
      model/openai_proxy.go
  7. +1
    -1
      model/query.go
  8. +3
    -4
      model/query_test.go
  9. +1
    -1
      model/util.go
  10. +3
    -3
      object/provider.go
  11. +4
    -4
      object/vector_embedding.go

ai/embedding.go → model/embedding.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"context" "context"

ai/ernie.go → model/ernie.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"context" "context"

ai/huggingface.go → model/huggingface.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"context" "context"

ai/model.go → model/model.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"io" "io"

ai/openai.go → model/openai.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"context" "context"

ai/openai_proxy.go → model/openai_proxy.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"github.com/casbin/casibase/proxy" "github.com/casbin/casibase/proxy"

ai/query.go → model/query.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"context" "context"

ai/query_test.go → model/query_test.go View File

@@ -15,12 +15,11 @@
//go:build !skipCi //go:build !skipCi
// +build !skipCi // +build !skipCi


package ai_test
package model_test


import ( import (
"testing" "testing"


"github.com/casbin/casibase/ai"
"github.com/casbin/casibase/object" "github.com/casbin/casibase/object"
"github.com/casbin/casibase/proxy" "github.com/casbin/casibase/proxy"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
@@ -30,10 +29,10 @@ func TestRun(t *testing.T) {
object.InitConfig() object.InitConfig()
proxy.InitHttpClient() proxy.InitHttpClient()


text := ai.QueryAnswerSafe("", "hi")
text := model.QueryAnswerSafe("", "hi")
println(text) println(text)
} }


func TestToken(t *testing.T) { func TestToken(t *testing.T) {
println(ai.GetTokenSize(openai.GPT3TextDavinci003, ""))
println(model.GetTokenSize(openai.GPT3TextDavinci003, ""))
} }

ai/util.go → model/util.go View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.


package ai
package model


import ( import (
"math" "math"

+ 3
- 3
object/provider.go View File

@@ -17,7 +17,7 @@ package object
import ( import (
"fmt" "fmt"


"github.com/casbin/casibase/ai"
"github.com/casbin/casibase/model"
"github.com/casbin/casibase/util" "github.com/casbin/casibase/util"
"xorm.io/core" "xorm.io/core"
) )
@@ -161,8 +161,8 @@ func (provider *Provider) GetId() string {
return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) return fmt.Sprintf("%s/%s", provider.Owner, provider.Name)
} }


func (p *Provider) GetModelProvider() (ai.ModelProvider, error) {
pProvider, err := ai.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }


+ 4
- 4
object/vector_embedding.go View File

@@ -20,7 +20,7 @@ import (
"path/filepath" "path/filepath"
"time" "time"


"github.com/casbin/casibase/ai"
"github.com/casbin/casibase/model"
"github.com/casbin/casibase/storage" "github.com/casbin/casibase/storage"
"github.com/casbin/casibase/txt" "github.com/casbin/casibase/txt"
"github.com/casbin/casibase/util" "github.com/casbin/casibase/util"
@@ -54,7 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object,
} }


func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
embedding, err := ai.GetEmbeddingSafe(authToken, text)
embedding, err := model.GetEmbeddingSafe(authToken, text)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -128,7 +128,7 @@ func getRelatedVectors(owner string) ([]*Vector, error) {
} }


func GetNearestVectorText(authToken string, owner string, question string) (string, error) { func GetNearestVectorText(authToken string, owner string, question string) (string, error) {
qVector, err := ai.GetEmbeddingSafe(authToken, question)
qVector, err := model.GetEmbeddingSafe(authToken, question)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -146,6 +146,6 @@ func GetNearestVectorText(authToken string, owner string, question string) (stri
nVectors = append(nVectors, candidate.Data) nVectors = append(nVectors, candidate.Data)
} }


i := ai.GetNearestVectorIndex(qVector, nVectors)
i := model.GetNearestVectorIndex(qVector, nVectors)
return vectors[i].Text, nil return vectors[i].Text, nil
} }

Loading…
Cancel
Save