| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "context" | "context" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "context" | "context" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "context" | "context" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "io" | "io" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "context" | "context" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "github.com/casbin/casibase/proxy" | "github.com/casbin/casibase/proxy" | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "context" | "context" | ||||
| @@ -15,12 +15,11 @@ | |||||
| //go:build !skipCi | //go:build !skipCi | ||||
| // +build !skipCi | // +build !skipCi | ||||
| package ai_test | |||||
| package model_test | |||||
| import ( | import ( | ||||
| "testing" | "testing" | ||||
| "github.com/casbin/casibase/ai" | |||||
| "github.com/casbin/casibase/object" | "github.com/casbin/casibase/object" | ||||
| "github.com/casbin/casibase/proxy" | "github.com/casbin/casibase/proxy" | ||||
| "github.com/sashabaranov/go-openai" | "github.com/sashabaranov/go-openai" | ||||
| @@ -30,10 +29,10 @@ func TestRun(t *testing.T) { | |||||
| object.InitConfig() | object.InitConfig() | ||||
| proxy.InitHttpClient() | proxy.InitHttpClient() | ||||
| text := ai.QueryAnswerSafe("", "hi") | |||||
| text := model.QueryAnswerSafe("", "hi") | |||||
| println(text) | println(text) | ||||
| } | } | ||||
| func TestToken(t *testing.T) { | func TestToken(t *testing.T) { | ||||
| println(ai.GetTokenSize(openai.GPT3TextDavinci003, "")) | |||||
| println(model.GetTokenSize(openai.GPT3TextDavinci003, "")) | |||||
| } | } | ||||
| @@ -12,7 +12,7 @@ | |||||
| // See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | // limitations under the License. | ||||
| package ai | |||||
| package model | |||||
| import ( | import ( | ||||
| "math" | "math" | ||||
| @@ -17,7 +17,7 @@ package object | |||||
| import ( | import ( | ||||
| "fmt" | "fmt" | ||||
| "github.com/casbin/casibase/ai" | |||||
| "github.com/casbin/casibase/model" | |||||
| "github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
| "xorm.io/core" | "xorm.io/core" | ||||
| ) | ) | ||||
| @@ -161,8 +161,8 @@ func (provider *Provider) GetId() string { | |||||
| return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) | return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) | ||||
| } | } | ||||
| func (p *Provider) GetModelProvider() (ai.ModelProvider, error) { | |||||
| pProvider, err := ai.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||||
| func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||||
| pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||||
| if err != nil { | if err != nil { | ||||
| return nil, err | return nil, err | ||||
| } | } | ||||
| @@ -20,7 +20,7 @@ import ( | |||||
| "path/filepath" | "path/filepath" | ||||
| "time" | "time" | ||||
| "github.com/casbin/casibase/ai" | |||||
| "github.com/casbin/casibase/model" | |||||
| "github.com/casbin/casibase/storage" | "github.com/casbin/casibase/storage" | ||||
| "github.com/casbin/casibase/txt" | "github.com/casbin/casibase/txt" | ||||
| "github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
| @@ -54,7 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, | |||||
| } | } | ||||
| func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | ||||
| embedding, err := ai.GetEmbeddingSafe(authToken, text) | |||||
| embedding, err := model.GetEmbeddingSafe(authToken, text) | |||||
| if err != nil { | if err != nil { | ||||
| return false, err | return false, err | ||||
| } | } | ||||
| @@ -128,7 +128,7 @@ func getRelatedVectors(owner string) ([]*Vector, error) { | |||||
| } | } | ||||
| func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | ||||
| qVector, err := ai.GetEmbeddingSafe(authToken, question) | |||||
| qVector, err := model.GetEmbeddingSafe(authToken, question) | |||||
| if err != nil { | if err != nil { | ||||
| return "", err | return "", err | ||||
| } | } | ||||
| @@ -146,6 +146,6 @@ func GetNearestVectorText(authToken string, owner string, question string) (stri | |||||
| nVectors = append(nVectors, candidate.Data) | nVectors = append(nVectors, candidate.Data) | ||||
| } | } | ||||
| i := ai.GetNearestVectorIndex(qVector, nVectors) | |||||
| i := model.GetNearestVectorIndex(qVector, nVectors) | |||||
| return vectors[i].Text, nil | return vectors[i].Text, nil | ||||
| } | } | ||||