@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"context" | "context" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"context" | "context" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"context" | "context" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"io" | "io" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"context" | "context" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"github.com/casbin/casibase/proxy" | "github.com/casbin/casibase/proxy" |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"context" | "context" |
@@ -15,12 +15,11 @@ | |||||
//go:build !skipCi | //go:build !skipCi | ||||
// +build !skipCi | // +build !skipCi | ||||
package ai_test | |||||
package model_test | |||||
import ( | import ( | ||||
"testing" | "testing" | ||||
"github.com/casbin/casibase/ai" | |||||
"github.com/casbin/casibase/object" | "github.com/casbin/casibase/object" | ||||
"github.com/casbin/casibase/proxy" | "github.com/casbin/casibase/proxy" | ||||
"github.com/sashabaranov/go-openai" | "github.com/sashabaranov/go-openai" | ||||
@@ -30,10 +29,10 @@ func TestRun(t *testing.T) { | |||||
object.InitConfig() | object.InitConfig() | ||||
proxy.InitHttpClient() | proxy.InitHttpClient() | ||||
text := ai.QueryAnswerSafe("", "hi") | |||||
text := model.QueryAnswerSafe("", "hi") | |||||
println(text) | println(text) | ||||
} | } | ||||
func TestToken(t *testing.T) { | func TestToken(t *testing.T) { | ||||
println(ai.GetTokenSize(openai.GPT3TextDavinci003, "")) | |||||
println(model.GetTokenSize(openai.GPT3TextDavinci003, "")) | |||||
} | } |
@@ -12,7 +12,7 @@ | |||||
// See the License for the specific language governing permissions and | // See the License for the specific language governing permissions and | ||||
// limitations under the License. | // limitations under the License. | ||||
package ai | |||||
package model | |||||
import ( | import ( | ||||
"math" | "math" |
@@ -17,7 +17,7 @@ package object | |||||
import ( | import ( | ||||
"fmt" | "fmt" | ||||
"github.com/casbin/casibase/ai" | |||||
"github.com/casbin/casibase/model" | |||||
"github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
"xorm.io/core" | "xorm.io/core" | ||||
) | ) | ||||
@@ -161,8 +161,8 @@ func (provider *Provider) GetId() string { | |||||
return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) | return fmt.Sprintf("%s/%s", provider.Owner, provider.Name) | ||||
} | } | ||||
func (p *Provider) GetModelProvider() (ai.ModelProvider, error) { | |||||
pProvider, err := ai.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||||
func (p *Provider) GetModelProvider() (model.ModelProvider, error) { | |||||
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -20,7 +20,7 @@ import ( | |||||
"path/filepath" | "path/filepath" | ||||
"time" | "time" | ||||
"github.com/casbin/casibase/ai" | |||||
"github.com/casbin/casibase/model" | |||||
"github.com/casbin/casibase/storage" | "github.com/casbin/casibase/storage" | ||||
"github.com/casbin/casibase/txt" | "github.com/casbin/casibase/txt" | ||||
"github.com/casbin/casibase/util" | "github.com/casbin/casibase/util" | ||||
@@ -54,7 +54,7 @@ func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, | |||||
} | } | ||||
func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) { | ||||
embedding, err := ai.GetEmbeddingSafe(authToken, text) | |||||
embedding, err := model.GetEmbeddingSafe(authToken, text) | |||||
if err != nil { | if err != nil { | ||||
return false, err | return false, err | ||||
} | } | ||||
@@ -128,7 +128,7 @@ func getRelatedVectors(owner string) ([]*Vector, error) { | |||||
} | } | ||||
func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | func GetNearestVectorText(authToken string, owner string, question string) (string, error) { | ||||
qVector, err := ai.GetEmbeddingSafe(authToken, question) | |||||
qVector, err := model.GetEmbeddingSafe(authToken, question) | |||||
if err != nil { | if err != nil { | ||||
return "", err | return "", err | ||||
} | } | ||||
@@ -146,6 +146,6 @@ func GetNearestVectorText(authToken string, owner string, question string) (stri | |||||
nVectors = append(nVectors, candidate.Data) | nVectors = append(nVectors, candidate.Data) | ||||
} | } | ||||
i := ai.GetNearestVectorIndex(qVector, nVectors) | |||||
i := model.GetNearestVectorIndex(qVector, nVectors) | |||||
return vectors[i].Text, nil | return vectors[i].Text, nil | ||||
} | } |