Co-authored-by: techknowlogick <techknowlogick@gitea.io>tags/v1.13.0-rc1
@@ -120,5 +120,5 @@ require ( | |||
mvdan.cc/xurls/v2 v2.1.0 | |||
strk.kbt.io/projects/go/libravatar v0.0.0-20191008002943-06d1c002b251 | |||
xorm.io/builder v0.3.7 | |||
xorm.io/xorm v1.0.1 | |||
xorm.io/xorm v1.0.2 | |||
) |
@@ -920,5 +920,5 @@ xorm.io/core v0.7.2 h1:mEO22A2Z7a3fPaZMk6gKL/jMD80iiyNwRrX5HOv3XLw= | |||
xorm.io/core v0.7.2/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= | |||
xorm.io/xorm v0.8.0 h1:iALxgJrX8O00f8Jk22GbZwPmxJNgssV5Mv4uc2HL9PM= | |||
xorm.io/xorm v0.8.0/go.mod h1:ZkJLEYLoVyg7amJK/5r779bHyzs2AU8f8VMiP6BM7uY= | |||
xorm.io/xorm v1.0.1 h1:/lITxpJtkZauNpdzj+L9CN/3OQxZaABrbergMcJu+Cw= | |||
xorm.io/xorm v1.0.1/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= | |||
xorm.io/xorm v1.0.2 h1:kZlCh9rqd1AzGwWitcrEEqHE1h1eaZE/ujU5/2tWEtg= | |||
xorm.io/xorm v1.0.2/go.mod h1:o4vnEsQ5V2F1/WK6w4XTwmiWJeGj82tqjAnHe44wVHY= |
@@ -23,7 +23,7 @@ func DefaultDBContext() DBContext { | |||
// Committer represents an interface to Commit or Close the dbcontext | |||
type Committer interface { | |||
Commit() error | |||
Close() | |||
Close() error | |||
} | |||
// TxDBContext represents a transaction DBContext | |||
@@ -878,7 +878,7 @@ strk.kbt.io/projects/go/libravatar | |||
# xorm.io/builder v0.3.7 | |||
## explicit | |||
xorm.io/builder | |||
# xorm.io/xorm v1.0.1 | |||
# xorm.io/xorm v1.0.2 | |||
## explicit | |||
xorm.io/xorm | |||
xorm.io/xorm/caches | |||
@@ -3,12 +3,13 @@ kind: pipeline | |||
name: testing | |||
steps: | |||
- name: test-vet | |||
image: golang:1.11 | |||
image: golang:1.11 # The lowest golang requirement | |||
environment: | |||
GO111MODULE: "on" | |||
GOPROXY: "https://goproxy.cn" | |||
commands: | |||
- go vet | |||
- make vet | |||
- make test | |||
when: | |||
event: | |||
- push | |||
@@ -23,10 +24,6 @@ steps: | |||
- make test-sqlite | |||
- TEST_CACHE_ENABLE=true make test-sqlite | |||
- TEST_QUOTE_POLICY=reserved make test-sqlite | |||
- go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \ | |||
./log/... ./migrate/... ./names/... ./schemas/... ./tags/... \ | |||
./internal/json/... ./internal/statements/... ./internal/utils/... \ | |||
when: | |||
event: | |||
- push | |||
@@ -34,3 +34,5 @@ test.db.sql | |||
.idea/ | |||
*coverage.out | |||
test.db | |||
integrations/*.sql |
@@ -7,8 +7,8 @@ TAGS ?= | |||
SED_INPLACE := sed -i | |||
GOFILES := $(shell find . -name "*.go" -type f) | |||
PACKAGES ?= $(shell GO111MODULE=on $(GO) list ./...) | |||
INTEGRATION_PACKAGES := xorm.io/xorm/integrations | |||
PACKAGES ?= $(filter-out $(INTEGRATION_PACKAGES),$(shell $(GO) list ./...)) | |||
TEST_COCKROACH_HOST ?= cockroach:26257 | |||
TEST_COCKROACH_SCHEMA ?= | |||
@@ -46,12 +46,12 @@ all: build | |||
.PHONY: build | |||
build: go-check $(GO_SOURCES) | |||
$(GO) build | |||
$(GO) build $(PACKAGES) | |||
.PHONY: clean | |||
clean: | |||
$(GO) clean -i ./... | |||
rm -rf *.sql *.log test.db *coverage.out coverage.all | |||
rm -rf *.sql *.log test.db *coverage.out coverage.all integrations/*.sql | |||
.PHONY: coverage | |||
coverage: | |||
@@ -92,7 +92,12 @@ help: | |||
@echo " - lint run code linter revive" | |||
@echo " - misspell check if a word is written wrong" | |||
@echo " - test run default unit test" | |||
@echo " - test-sqlite run unit test for sqlite" | |||
@echo " - test-cockroach run integration tests for cockroach" | |||
@echo " - test-mysql run integration tests for mysql" | |||
@echo " - test-mssql run integration tests for mssql" | |||
@echo " - test-postgres run integration tests for postgres" | |||
@echo " - test-sqlite run integration tests for sqlite" | |||
@echo " - test-tidb run integration tests for tidb" | |||
@echo " - vet examines Go source code and reports suspicious constructs" | |||
.PHONY: lint | |||
@@ -120,95 +125,96 @@ misspell-check: | |||
misspell -error -i unknwon,destory $(GOFILES) | |||
.PHONY: test | |||
test: test-sqlite | |||
test: go-check | |||
$(GO) test $(PACKAGES) | |||
.PNONY: test-cockroach | |||
test-cockroach: go-check | |||
$(GO) test -race -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
-conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ | |||
-ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-cockroach\#% | |||
test-cockroach\#%: go-check | |||
$(GO) test -race -run $* -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
-conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ | |||
-ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-mssql | |||
test-mssql: go-check | |||
$(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ | |||
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-mssql\#% | |||
test-mssql\#%: go-check | |||
$(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ | |||
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-mymysql | |||
test-mymysql: go-check | |||
$(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ | |||
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-mymysql\#% | |||
test-mymysql\#%: go-check | |||
$(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ | |||
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-mysql | |||
test-mysql: go-check | |||
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ | |||
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-mysql\#% | |||
test-mysql\#%: go-check | |||
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ | |||
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ | |||
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-postgres | |||
test-postgres: go-check | |||
$(GO) test -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-postgres\#% | |||
test-postgres\#%: go-check | |||
$(GO) test -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ | |||
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-sqlite | |||
test-sqlite: go-check | |||
$(GO) test -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-sqlite-schema | |||
test-sqlite-schema: go-check | |||
$(GO) test -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-sqlite\#% | |||
test-sqlite\#%: go-check | |||
$(GO) test -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PNONY: test-tidb | |||
test-tidb: go-check | |||
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ | |||
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: test-tidb\#% | |||
test-tidb\#%: go-check | |||
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ | |||
$(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ | |||
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ | |||
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic | |||
.PHONY: vet | |||
vet: | |||
$(GO) vet $(PACKAGES) | |||
$(GO) vet $(shell $(GO) list ./...) |
@@ -67,6 +67,8 @@ Drivers for Go's sql package which currently support database/sql includes: | |||
* Create Engine | |||
Firstly, we should new an engine for a database. | |||
```Go | |||
engine, err := xorm.NewEngine(driverName, dataSourceName) | |||
``` | |||
@@ -418,7 +420,7 @@ res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) | |||
## Contributing | |||
If you want to pull request, please see [CONTRIBUTING](https://gitea.com/xorm/xorm/src/branch/master/CONTRIBUTING.md). And we also provide [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) to discuss. | |||
If you want to pull request, please see [CONTRIBUTING](https://gitea.com/xorm/xorm/src/branch/master/CONTRIBUTING.md). And you can also go to [Xorm on discourse](https://xorm.discourse.group) to discuss. | |||
## Credits | |||
@@ -0,0 +1,75 @@ | |||
// Copyright 2020 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package contexts | |||
import ( | |||
"context" | |||
"database/sql" | |||
"time" | |||
) | |||
// ContextHook represents a hook context | |||
type ContextHook struct { | |||
start time.Time | |||
Ctx context.Context | |||
SQL string // log content or SQL | |||
Args []interface{} // if it's a SQL, it's the arguments | |||
Result sql.Result | |||
ExecuteTime time.Duration | |||
Err error // SQL executed error | |||
} | |||
// NewContextHook return context for hook | |||
func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook { | |||
return &ContextHook{ | |||
start: time.Now(), | |||
Ctx: ctx, | |||
SQL: sql, | |||
Args: args, | |||
} | |||
} | |||
func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) { | |||
c.Ctx = ctx | |||
c.Result = result | |||
c.Err = err | |||
c.ExecuteTime = time.Now().Sub(c.start) | |||
} | |||
type Hook interface { | |||
BeforeProcess(c *ContextHook) (context.Context, error) | |||
AfterProcess(c *ContextHook) error | |||
} | |||
type Hooks struct { | |||
hooks []Hook | |||
} | |||
func (h *Hooks) AddHook(hooks ...Hook) { | |||
h.hooks = append(h.hooks, hooks...) | |||
} | |||
func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { | |||
ctx := c.Ctx | |||
for _, h := range h.hooks { | |||
var err error | |||
ctx, err = h.BeforeProcess(c) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
return ctx, nil | |||
} | |||
func (h *Hooks) AfterProcess(c *ContextHook) error { | |||
firstErr := c.Err | |||
for _, h := range h.hooks { | |||
err := h.AfterProcess(c) | |||
if err != nil && firstErr == nil { | |||
firstErr = err | |||
} | |||
} | |||
return firstErr | |||
} |
@@ -25,11 +25,10 @@ func strconvErr(err error) error { | |||
func cloneBytes(b []byte) []byte { | |||
if b == nil { | |||
return nil | |||
} else { | |||
c := make([]byte, len(b)) | |||
copy(c, b) | |||
return c | |||
} | |||
c := make([]byte, len(b)) | |||
copy(c, b) | |||
return c | |||
} | |||
func asString(src interface{}) string { | |||
@@ -285,56 +284,6 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { | |||
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) | |||
} | |||
func convertFloat(v interface{}) (float64, error) { | |||
switch v.(type) { | |||
case float32: | |||
return float64(v.(float32)), nil | |||
case float64: | |||
return v.(float64), nil | |||
case string: | |||
i, err := strconv.ParseFloat(v.(string), 64) | |||
if err != nil { | |||
return 0, err | |||
} | |||
return i, nil | |||
case []byte: | |||
i, err := strconv.ParseFloat(string(v.([]byte)), 64) | |||
if err != nil { | |||
return 0, err | |||
} | |||
return i, nil | |||
} | |||
return 0, fmt.Errorf("unsupported type: %v", v) | |||
} | |||
func convertInt(v interface{}) (int64, error) { | |||
switch v.(type) { | |||
case int: | |||
return int64(v.(int)), nil | |||
case int8: | |||
return int64(v.(int8)), nil | |||
case int16: | |||
return int64(v.(int16)), nil | |||
case int32: | |||
return int64(v.(int32)), nil | |||
case int64: | |||
return v.(int64), nil | |||
case []byte: | |||
i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) | |||
if err != nil { | |||
return 0, err | |||
} | |||
return i, nil | |||
case string: | |||
i, err := strconv.ParseInt(v.(string), 10, 64) | |||
if err != nil { | |||
return 0, err | |||
} | |||
return i, nil | |||
} | |||
return 0, fmt.Errorf("unsupported type: %v", v) | |||
} | |||
func asBool(bs []byte) (bool, error) { | |||
if len(bs) == 0 { | |||
return false, nil | |||
@@ -12,8 +12,8 @@ import ( | |||
"reflect" | |||
"regexp" | |||
"sync" | |||
"time" | |||
"xorm.io/xorm/contexts" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/names" | |||
) | |||
@@ -88,6 +88,7 @@ type DB struct { | |||
reflectCache map[reflect.Type]*cacheStruct | |||
reflectCacheMutex sync.RWMutex | |||
Logger log.ContextLogger | |||
hooks contexts.Hooks | |||
} | |||
// Open opens a database | |||
@@ -118,7 +119,7 @@ func (db *DB) NeedLogSQL(ctx context.Context) bool { | |||
return false | |||
} | |||
v := ctx.Value("__xorm_show_sql") | |||
v := ctx.Value(log.SessionShowSQLKey) | |||
if showSQL, ok := v.(bool); ok { | |||
return showSQL | |||
} | |||
@@ -140,26 +141,14 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { | |||
// QueryContext overwrites sql.DB.QueryContext | |||
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { | |||
start := time.Now() | |||
showSQL := db.NeedLogSQL(ctx) | |||
if showSQL { | |||
db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, query, args) | |||
ctx, err := db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
rows, err := db.DB.QueryContext(ctx, query, args...) | |||
if showSQL { | |||
db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := db.afterProcess(hookCtx); err != nil { | |||
if rows != nil { | |||
rows.Close() | |||
} | |||
@@ -239,7 +228,7 @@ var ( | |||
re = regexp.MustCompile(`[?](\w+)`) | |||
) | |||
// ExecMapContext exec map with context.Context | |||
// ExecMapContext exec map with context.ContextHook | |||
// insert into (name) values (?) | |||
// insert into (name) values (?name) | |||
func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { | |||
@@ -263,28 +252,42 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ | |||
} | |||
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | |||
start := time.Now() | |||
showSQL := db.NeedLogSQL(ctx) | |||
if showSQL { | |||
db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, query, args) | |||
ctx, err := db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := db.DB.ExecContext(ctx, query, args...) | |||
if showSQL { | |||
db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
hookCtx.End(ctx, res, err) | |||
if err := db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return res, err | |||
return res, nil | |||
} | |||
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { | |||
return db.ExecStructContext(context.Background(), query, st) | |||
} | |||
func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) { | |||
if db.NeedLogSQL(c.Ctx) { | |||
db.Logger.BeforeSQL(log.LogContext(*c)) | |||
} | |||
ctx, err := db.hooks.BeforeProcess(c) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return ctx, nil | |||
} | |||
func (db *DB) afterProcess(c *contexts.ContextHook) error { | |||
err := db.hooks.AfterProcess(c) | |||
if db.NeedLogSQL(c.Ctx) { | |||
db.Logger.AfterSQL(log.LogContext(*c)) | |||
} | |||
return err | |||
} | |||
func (db *DB) AddHook(h ...contexts.Hook) { | |||
db.hooks.AddHook(h...) | |||
} |
@@ -9,9 +9,8 @@ import ( | |||
"database/sql" | |||
"errors" | |||
"reflect" | |||
"time" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/contexts" | |||
) | |||
// Stmt reprents a stmt objects | |||
@@ -30,28 +29,16 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { | |||
i++ | |||
return "?" | |||
}) | |||
start := time.Now() | |||
showSQL := db.NeedLogSQL(ctx) | |||
if showSQL { | |||
db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "PREPARE", | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) | |||
ctx, err := db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
stmt, err := db.DB.PrepareContext(ctx, query) | |||
if showSQL { | |||
db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "PREPARE", | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return &Stmt{stmt, db, names, query}, nil | |||
} | |||
@@ -94,49 +81,28 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { | |||
} | |||
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { | |||
start := time.Now() | |||
showSQL := s.db.NeedLogSQL(ctx) | |||
if showSQL { | |||
s.db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: s.query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, s.query, args) | |||
ctx, err := s.db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := s.Stmt.ExecContext(ctx, args) | |||
if showSQL { | |||
s.db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: s.query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
hookCtx.End(ctx, res, err) | |||
if err := s.db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return res, err | |||
return res, nil | |||
} | |||
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { | |||
start := time.Now() | |||
showSQL := s.db.NeedLogSQL(ctx) | |||
if showSQL { | |||
s.db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: s.query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, s.query, args) | |||
ctx, err := s.db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
rows, err := s.Stmt.QueryContext(ctx, args...) | |||
if showSQL { | |||
s.db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: s.query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := s.db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return &Rows{rows, s.db}, nil | |||
@@ -175,7 +141,7 @@ func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, e | |||
args[i] = vv.Elem().FieldByName(k).Interface() | |||
} | |||
return s.Query(args...) | |||
return s.QueryContext(ctx, args...) | |||
} | |||
func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { | |||
@@ -7,9 +7,8 @@ package core | |||
import ( | |||
"context" | |||
"database/sql" | |||
"time" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/contexts" | |||
) | |||
var ( | |||
@@ -23,24 +22,14 @@ type Tx struct { | |||
} | |||
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { | |||
start := time.Now() | |||
showSQL := db.NeedLogSQL(ctx) | |||
if showSQL { | |||
db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "BEGIN TRANSACTION", | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) | |||
ctx, err := db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
tx, err := db.DB.BeginTx(ctx, opts) | |||
if showSQL { | |||
db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "BEGIN TRANSACTION", | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return &Tx{tx, db}, nil | |||
@@ -58,25 +47,14 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { | |||
i++ | |||
return "?" | |||
}) | |||
start := time.Now() | |||
showSQL := tx.db.NeedLogSQL(ctx) | |||
if showSQL { | |||
tx.db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "PREPARE", | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) | |||
ctx, err := tx.db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
stmt, err := tx.Tx.PrepareContext(ctx, query) | |||
if showSQL { | |||
tx.db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: "PREPARE", | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := tx.db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return &Stmt{stmt, tx.db, names, query}, nil | |||
@@ -116,24 +94,15 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ | |||
} | |||
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | |||
start := time.Now() | |||
showSQL := tx.db.NeedLogSQL(ctx) | |||
if showSQL { | |||
tx.db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, query, args) | |||
ctx, err := tx.db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := tx.Tx.ExecContext(ctx, query, args...) | |||
if showSQL { | |||
tx.db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
hookCtx.End(ctx, res, err) | |||
if err := tx.db.afterProcess(hookCtx); err != nil { | |||
return nil, err | |||
} | |||
return res, err | |||
} | |||
@@ -143,26 +112,14 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { | |||
} | |||
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { | |||
start := time.Now() | |||
showSQL := tx.db.NeedLogSQL(ctx) | |||
if showSQL { | |||
tx.db.Logger.BeforeSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
}) | |||
hookCtx := contexts.NewContextHook(ctx, query, args) | |||
ctx, err := tx.db.beforeProcess(hookCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
rows, err := tx.Tx.QueryContext(ctx, query, args...) | |||
if showSQL { | |||
tx.db.Logger.AfterSQL(log.LogContext{ | |||
Ctx: ctx, | |||
SQL: query, | |||
Args: args, | |||
ExecuteTime: time.Now().Sub(start), | |||
Err: err, | |||
}) | |||
} | |||
if err != nil { | |||
hookCtx.End(ctx, nil, err) | |||
if err := tx.db.afterProcess(hookCtx); err != nil { | |||
if rows != nil { | |||
rows.Close() | |||
} | |||
@@ -96,51 +96,6 @@ func (b *Base) DBType() schemas.DBType { | |||
return b.uri.DBType | |||
} | |||
// String generate column description string according dialect | |||
func (b *Base) String(col *schemas.Column) string { | |||
sql := b.dialect.Quoter().Quote(col.Name) + " " | |||
sql += b.dialect.SQLType(col) + " " | |||
if col.IsPrimaryKey { | |||
sql += "PRIMARY KEY " | |||
if col.IsAutoIncrement { | |||
sql += b.dialect.AutoIncrStr() + " " | |||
} | |||
} | |||
if col.Default != "" { | |||
sql += "DEFAULT " + col.Default + " " | |||
} | |||
if col.Nullable { | |||
sql += "NULL " | |||
} else { | |||
sql += "NOT NULL " | |||
} | |||
return sql | |||
} | |||
// StringNoPk generate column description string according dialect without primary keys | |||
func (b *Base) StringNoPk(col *schemas.Column) string { | |||
sql := b.dialect.Quoter().Quote(col.Name) + " " | |||
sql += b.dialect.SQLType(col) + " " | |||
if col.Default != "" { | |||
sql += "DEFAULT " + col.Default + " " | |||
} | |||
if col.Nullable { | |||
sql += "NULL " | |||
} else { | |||
sql += "NOT NULL " | |||
} | |||
return sql | |||
} | |||
func (b *Base) FormatBytes(bs []byte) string { | |||
return fmt.Sprintf("0x%x", bs) | |||
} | |||
@@ -178,8 +133,8 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa | |||
} | |||
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { | |||
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), | |||
db.String(col)) | |||
s, _ := ColumnString(db.dialect, col, true) | |||
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s) | |||
} | |||
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { | |||
@@ -207,7 +162,8 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { | |||
} | |||
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { | |||
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col)) | |||
s, _ := ColumnString(db.dialect, col, false) | |||
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, s) | |||
} | |||
func (b *Base) ForUpdateSQL(query string) string { | |||
@@ -266,3 +222,63 @@ func regDrvsNDialects() bool { | |||
func init() { | |||
regDrvsNDialects() | |||
} | |||
// ColumnString generate column description string according dialect | |||
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { | |||
bd := strings.Builder{} | |||
if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { | |||
return "", err | |||
} | |||
if err := bd.WriteByte(' '); err != nil { | |||
return "", err | |||
} | |||
if _, err := bd.WriteString(dialect.SQLType(col)); err != nil { | |||
return "", err | |||
} | |||
if err := bd.WriteByte(' '); err != nil { | |||
return "", err | |||
} | |||
if includePrimaryKey && col.IsPrimaryKey { | |||
if _, err := bd.WriteString("PRIMARY KEY "); err != nil { | |||
return "", err | |||
} | |||
if col.IsAutoIncrement { | |||
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { | |||
return "", err | |||
} | |||
if err := bd.WriteByte(' '); err != nil { | |||
return "", err | |||
} | |||
} | |||
} | |||
if col.Default != "" { | |||
if _, err := bd.WriteString("DEFAULT "); err != nil { | |||
return "", err | |||
} | |||
if _, err := bd.WriteString(col.Default); err != nil { | |||
return "", err | |||
} | |||
if err := bd.WriteByte(' '); err != nil { | |||
return "", err | |||
} | |||
} | |||
if col.Nullable { | |||
if _, err := bd.WriteString("NULL "); err != nil { | |||
return "", err | |||
} | |||
} else { | |||
if _, err := bd.WriteString("NOT NULL "); err != nil { | |||
return "", err | |||
} | |||
} | |||
return bd.String(), nil | |||
} |
@@ -205,7 +205,11 @@ var ( | |||
"PROC": true, | |||
} | |||
mssqlQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve} | |||
mssqlQuoter = schemas.Quoter{ | |||
Prefix: '[', | |||
Suffix: ']', | |||
IsReserved: schemas.AlwaysReserve, | |||
} | |||
) | |||
type mssql struct { | |||
@@ -501,11 +505,8 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin | |||
for _, colName := range table.ColumnsSeq() { | |||
col := table.GetColumn(colName) | |||
if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += db.String(col) | |||
} else { | |||
sql += db.StringNoPk(col) | |||
} | |||
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) | |||
sql += s | |||
sql = strings.TrimSpace(sql) | |||
sql += ", " | |||
} | |||
@@ -162,7 +162,11 @@ var ( | |||
"ZEROFILL": true, | |||
} | |||
mysqlQuoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve} | |||
mysqlQuoter = schemas.Quoter{ | |||
Prefix: '`', | |||
Suffix: '`', | |||
IsReserved: schemas.AlwaysReserve, | |||
} | |||
) | |||
type mysql struct { | |||
@@ -293,8 +297,8 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa | |||
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { | |||
quoter := db.dialect.Quoter() | |||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), | |||
db.String(col)) | |||
s, _ := ColumnString(db, col, true) | |||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) | |||
if len(col.Comment) > 0 { | |||
sql += " COMMENT '" + col.Comment + "'" | |||
} | |||
@@ -304,7 +308,8 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { | |||
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { | |||
args := []interface{}{db.uri.DBName, tableName} | |||
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + | |||
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" | |||
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + | |||
" ORDER BY `INFORMATION_SCHEMA`.`COLUMNS`.ORDINAL_POSITION" | |||
rows, err := queryer.QueryContext(ctx, s, args...) | |||
if err != nil { | |||
@@ -525,11 +530,8 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin | |||
for _, colName := range table.ColumnsSeq() { | |||
col := table.GetColumn(colName) | |||
if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += db.String(col) | |||
} else { | |||
sql += db.StringNoPk(col) | |||
} | |||
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) | |||
sql += s | |||
sql = strings.TrimSpace(sql) | |||
if len(col.Comment) > 0 { | |||
sql += " COMMENT '" + col.Comment + "'" | |||
@@ -499,7 +499,11 @@ var ( | |||
"ZONE": true, | |||
} | |||
oracleQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve} | |||
oracleQuoter = schemas.Quoter{ | |||
Prefix: '"', | |||
Suffix: '"', | |||
IsReserved: schemas.AlwaysReserve, | |||
} | |||
) | |||
type oracle struct { | |||
@@ -572,7 +576,8 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri | |||
/*if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += col.String(b.dialect) | |||
} else {*/ | |||
sql += db.StringNoPk(col) | |||
s, _ := ColumnString(db, col, false) | |||
sql += s | |||
// } | |||
sql = strings.TrimSpace(sql) | |||
sql += ", " | |||
@@ -767,7 +767,11 @@ var ( | |||
"ZONE": true, | |||
} | |||
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} | |||
postgresQuoter = schemas.Quoter{ | |||
Prefix: '"', | |||
Suffix: '"', | |||
IsReserved: schemas.AlwaysReserve, | |||
} | |||
) | |||
var ( | |||
@@ -908,11 +912,8 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st | |||
for _, colName := range table.ColumnsSeq() { | |||
col := table.GetColumn(colName) | |||
if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += db.String(col) | |||
} else { | |||
sql += db.StringNoPk(col) | |||
} | |||
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) | |||
sql += s | |||
sql = strings.TrimSpace(sql) | |||
sql += ", " | |||
} | |||
@@ -144,7 +144,11 @@ var ( | |||
"WITHOUT": true, | |||
} | |||
sqlite3Quoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve} | |||
sqlite3Quoter = schemas.Quoter{ | |||
Prefix: '`', | |||
Suffix: '`', | |||
IsReserved: schemas.AlwaysReserve, | |||
} | |||
) | |||
type sqlite3 struct { | |||
@@ -260,11 +264,8 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]str | |||
for _, colName := range table.ColumnsSeq() { | |||
col := table.GetColumn(colName) | |||
if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += db.String(col) | |||
} else { | |||
sql += db.StringNoPk(col) | |||
} | |||
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) | |||
sql += s | |||
sql = strings.TrimSpace(sql) | |||
sql += ", " | |||
} | |||
@@ -8,7 +8,7 @@ Package xorm is a simple and powerful ORM for Go. | |||
Installation | |||
Make sure you have installed Go 1.6+ and then: | |||
Make sure you have installed Go 1.11+ and then: | |||
go get xorm.io/xorm | |||
@@ -12,11 +12,13 @@ import ( | |||
"io" | |||
"os" | |||
"reflect" | |||
"runtime" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"xorm.io/xorm/caches" | |||
"xorm.io/xorm/contexts" | |||
"xorm.io/xorm/core" | |||
"xorm.io/xorm/dialects" | |||
"xorm.io/xorm/internal/utils" | |||
@@ -42,16 +44,79 @@ type Engine struct { | |||
TZLocation *time.Location // The timezone of the application | |||
DatabaseTZ *time.Location // The timezone of the database | |||
logSessionID bool // create session id | |||
} | |||
// NewEngine new a db manager according to the parameter. Currently support four | |||
// drivers | |||
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||
dialect, err := dialects.OpenDialect(driverName, dataSourceName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
db, err := core.Open(driverName, dataSourceName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
cacherMgr := caches.NewManager() | |||
mapper := names.NewCacheMapper(new(names.SnakeMapper)) | |||
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) | |||
engine := &Engine{ | |||
dialect: dialect, | |||
TZLocation: time.Local, | |||
defaultContext: context.Background(), | |||
cacherMgr: cacherMgr, | |||
tagParser: tagParser, | |||
driverName: driverName, | |||
dataSourceName: dataSourceName, | |||
db: db, | |||
logSessionID: false, | |||
} | |||
if dialect.URI().DBType == schemas.SQLITE { | |||
engine.DatabaseTZ = time.UTC | |||
} else { | |||
engine.DatabaseTZ = time.Local | |||
} | |||
logger := log.NewSimpleLogger(os.Stdout) | |||
logger.SetLevel(log.LOG_INFO) | |||
engine.SetLogger(log.NewLoggerAdapter(logger)) | |||
runtime.SetFinalizer(engine, func(engine *Engine) { | |||
engine.Close() | |||
}) | |||
return engine, nil | |||
} | |||
// NewEngineWithParams new a db manager with params. The params will be passed to dialects. | |||
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { | |||
engine, err := NewEngine(driverName, dataSourceName) | |||
engine.dialect.SetParams(params) | |||
return engine, err | |||
} | |||
// EnableSessionID if enable session id | |||
func (engine *Engine) EnableSessionID(enable bool) { | |||
engine.logSessionID = enable | |||
} | |||
// SetCacher sets cacher for the table | |||
func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) { | |||
engine.cacherMgr.SetCacher(tableName, cacher) | |||
} | |||
// GetCacher returns the cachher of the special table | |||
func (engine *Engine) GetCacher(tableName string) caches.Cacher { | |||
return engine.cacherMgr.GetCacher(tableName) | |||
} | |||
// SetQuotePolicy sets the special quote policy | |||
func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { | |||
engine.dialect.SetQuotePolicy(quotePolicy) | |||
} | |||
@@ -222,9 +287,7 @@ func (engine *Engine) Dialect() dialects.Dialect { | |||
// NewSession New a session | |||
func (engine *Engine) NewSession() *Session { | |||
session := &Session{engine: engine} | |||
session.Init() | |||
return session | |||
return newSession(engine) | |||
} | |||
// Close the engine | |||
@@ -753,81 +816,11 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { | |||
return session.IsTableExist(beanOrTableName) | |||
} | |||
// IDOf get id from one struct | |||
func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) { | |||
return engine.IDOfV(reflect.ValueOf(bean)) | |||
} | |||
// TableName returns table name with schema prefix if has | |||
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { | |||
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) | |||
} | |||
// IDOfV get id from one value of struct | |||
func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) { | |||
return engine.idOfV(rv) | |||
} | |||
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { | |||
v := reflect.Indirect(rv) | |||
table, err := engine.tagParser.ParseWithCache(v) | |||
if err != nil { | |||
return nil, err | |||
} | |||
pk := make([]interface{}, len(table.PrimaryKeys)) | |||
for i, col := range table.PKColumns() { | |||
var err error | |||
fieldName := col.FieldName | |||
for { | |||
parts := strings.SplitN(fieldName, ".", 2) | |||
if len(parts) == 1 { | |||
break | |||
} | |||
v = v.FieldByName(parts[0]) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
if v.Kind() != reflect.Struct { | |||
return nil, ErrUnSupportedType | |||
} | |||
fieldName = parts[1] | |||
} | |||
pkField := v.FieldByName(fieldName) | |||
switch pkField.Kind() { | |||
case reflect.String: | |||
pk[i], err = engine.idTypeAssertion(col, pkField.String()) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10)) | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
// id of uint will be converted to int64 | |||
pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10)) | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
return schemas.PK(pk), nil | |||
} | |||
func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) { | |||
if col.SQLType.IsNumeric() { | |||
n, err := strconv.ParseInt(sid, 10, 64) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return n, nil | |||
} else if col.SQLType.IsText() { | |||
return sid, nil | |||
} else { | |||
return nil, errors.New("not supported") | |||
} | |||
} | |||
// CreateIndexes create indexes | |||
func (engine *Engine) CreateIndexes(bean interface{}) error { | |||
session := engine.NewSession() | |||
@@ -1225,6 +1218,10 @@ func (engine *Engine) SetSchema(schema string) { | |||
engine.dialect.URI().SetSchema(schema) | |||
} | |||
func (engine *Engine) AddHook(hook contexts.Hook) { | |||
engine.db.AddHook(hook) | |||
} | |||
// Unscoped always disable struct tag "deleted" | |||
func (engine *Engine) Unscoped() *Session { | |||
session := engine.NewSession() | |||
@@ -1236,7 +1233,7 @@ func (engine *Engine) tbNameWithSchema(v string) string { | |||
return dialects.TableNameWithSchema(engine.dialect, v) | |||
} | |||
// Context creates a session with the context | |||
// ContextHook creates a session with the context | |||
func (engine *Engine) Context(ctx context.Context) *Session { | |||
session := engine.NewSession() | |||
session.isAutoClose = true | |||
@@ -9,6 +9,7 @@ import ( | |||
"time" | |||
"xorm.io/xorm/caches" | |||
"xorm.io/xorm/contexts" | |||
"xorm.io/xorm/dialects" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/names" | |||
@@ -78,7 +79,7 @@ func (eg *EngineGroup) Close() error { | |||
return nil | |||
} | |||
// Context returned a group session | |||
// ContextHook returned a group session | |||
func (eg *EngineGroup) Context(ctx context.Context) *Session { | |||
sess := eg.NewSession() | |||
sess.isAutoClose = true | |||
@@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { | |||
} | |||
} | |||
func (eg *EngineGroup) AddHook(hook contexts.Hook) { | |||
eg.Engine.AddHook(hook) | |||
for i := 0; i < len(eg.slaves); i++ { | |||
eg.slaves[i].AddHook(hook) | |||
} | |||
} | |||
// SetLogLevel sets the logger level | |||
func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { | |||
eg.Engine.SetLogLevel(level) | |||
@@ -181,6 +189,7 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup { | |||
return eg | |||
} | |||
// SetQuotePolicy sets the special quote policy | |||
func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { | |||
eg.Engine.SetQuotePolicy(quotePolicy) | |||
for i := 0; i < len(eg.slaves); i++ { | |||
@@ -51,6 +51,7 @@ func WeightRandomPolicy(weights []int) GroupPolicyHandler { | |||
} | |||
} | |||
// RoundRobinPolicy returns a group policy handler | |||
func RoundRobinPolicy() GroupPolicyHandler { | |||
var pos = -1 | |||
var lock sync.Mutex | |||
@@ -68,6 +69,7 @@ func RoundRobinPolicy() GroupPolicyHandler { | |||
} | |||
} | |||
// WeightRoundRobinPolicy returns a group policy handler | |||
func WeightRoundRobinPolicy(weights []int) GroupPolicyHandler { | |||
var rands = make([]int, 0, len(weights)) | |||
for i := 0; i < len(weights); i++ { | |||
@@ -9,6 +9,8 @@ import ( | |||
) | |||
var ( | |||
// ErrPtrSliceType represents a type error | |||
ErrPtrSliceType = errors.New("A point to a slice is needed") | |||
// ErrParamsType params error | |||
ErrParamsType = errors.New("Params type error") | |||
// ErrTableNotFound table not found error | |||
@@ -11,6 +11,7 @@ import ( | |||
"time" | |||
"xorm.io/xorm/caches" | |||
"xorm.io/xorm/contexts" | |||
"xorm.io/xorm/dialects" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/names" | |||
@@ -111,6 +112,7 @@ type EngineInterface interface { | |||
SetTableMapper(names.Mapper) | |||
SetTZDatabase(tz *time.Location) | |||
SetTZLocation(tz *time.Location) | |||
AddHook(hook contexts.Hook) | |||
ShowSQL(show ...bool) | |||
Sync(...interface{}) error | |||
Sync2(...interface{}) error | |||
@@ -5,6 +5,7 @@ | |||
package statements | |||
import ( | |||
"fmt" | |||
"strings" | |||
"xorm.io/builder" | |||
@@ -23,18 +24,15 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem | |||
return nil | |||
} | |||
// GenInsertSQL generates insert beans SQL | |||
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { | |||
var ( | |||
buf = builder.NewWriter() | |||
exprs = statement.ExprColumns | |||
table = statement.RefTable | |||
tableName = statement.TableName() | |||
exprs = statement.ExprColumns | |||
colPlaces = strings.Repeat("?, ", len(colNames)) | |||
) | |||
if exprs.Len() <= 0 && len(colPlaces) > 0 { | |||
colPlaces = colPlaces[0 : len(colPlaces)-2] | |||
} | |||
var buf = builder.NewWriter() | |||
if _, err := buf.WriteString("INSERT INTO "); err != nil { | |||
return "", nil, err | |||
} | |||
@@ -43,7 +41,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) | |||
return "", nil, err | |||
} | |||
if len(colPlaces) <= 0 { | |||
if len(colNames) <= 0 { | |||
if statement.dialect.URI().DBType == schemas.MYSQL { | |||
if _, err := buf.WriteString(" VALUES ()"); err != nil { | |||
return "", nil, err | |||
@@ -65,13 +63,14 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) | |||
return "", nil, err | |||
} | |||
if _, err := buf.WriteString(")"); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil { | |||
return "", nil, err | |||
} | |||
if statement.Conds().IsValid() { | |||
if _, err := buf.WriteString(")"); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil { | |||
return "", nil, err | |||
} | |||
if _, err := buf.WriteString(" SELECT "); err != nil { | |||
return "", nil, err | |||
} | |||
@@ -105,21 +104,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) | |||
return "", nil, err | |||
} | |||
} else { | |||
buf.Append(args...) | |||
if _, err := buf.WriteString(")"); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil { | |||
return "", nil, err | |||
} | |||
if _, err := buf.WriteString(" VALUES ("); err != nil { | |||
return "", nil, err | |||
} | |||
if _, err := buf.WriteString(colPlaces); err != nil { | |||
if err := statement.WriteArgs(buf, args); err != nil { | |||
return "", nil, err | |||
} | |||
if len(exprs.Args) > 0 { | |||
if _, err := buf.WriteString(","); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
if err := exprs.WriteArgs(buf); err != nil { | |||
return "", nil, err | |||
} | |||
@@ -141,3 +139,69 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) | |||
return buf.String(), buf.Args(), nil | |||
} | |||
// GenInsertMapSQL generates insert map SQL | |||
func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { | |||
var ( | |||
buf = builder.NewWriter() | |||
exprs = statement.ExprColumns | |||
tableName = statement.TableName() | |||
) | |||
if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { | |||
return "", nil, err | |||
} | |||
// if insert where | |||
if statement.Conds().IsValid() { | |||
if _, err := buf.WriteString(") SELECT "); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.WriteArgs(buf, args); err != nil { | |||
return "", nil, err | |||
} | |||
if len(exprs.Args) > 0 { | |||
if _, err := buf.WriteString(","); err != nil { | |||
return "", nil, err | |||
} | |||
if err := exprs.WriteArgs(buf); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.Conds().WriteTo(buf); err != nil { | |||
return "", nil, err | |||
} | |||
} else { | |||
if _, err := buf.WriteString(") VALUES ("); err != nil { | |||
return "", nil, err | |||
} | |||
if err := statement.WriteArgs(buf, args); err != nil { | |||
return "", nil, err | |||
} | |||
if len(exprs.Args) > 0 { | |||
if _, err := buf.WriteString(","); err != nil { | |||
return "", nil, err | |||
} | |||
if err := exprs.WriteArgs(buf); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
if _, err := buf.WriteString(")"); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
return buf.String(), buf.Args(), nil | |||
} |
@@ -20,6 +20,21 @@ var ( | |||
uintType = reflect.TypeOf(uint64(0)) | |||
) | |||
// ErrIDConditionWithNoTable represents an error there is no reference table with an ID condition | |||
type ErrIDConditionWithNoTable struct { | |||
ID schemas.PK | |||
} | |||
func (err ErrIDConditionWithNoTable) Error() string { | |||
return fmt.Sprintf("ID condition %#v need reference table", err.ID) | |||
} | |||
// IsIDConditionWithNoTableErr return true if the err is ErrIDConditionWithNoTable | |||
func IsIDConditionWithNoTableErr(err error) bool { | |||
_, ok := err.(ErrIDConditionWithNoTable) | |||
return ok | |||
} | |||
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" | |||
func (statement *Statement) ID(id interface{}) *Statement { | |||
switch t := id.(type) { | |||
@@ -58,13 +73,17 @@ func (statement *Statement) ID(id interface{}) *Statement { | |||
return statement | |||
} | |||
// ProcessIDParam handles the process of id condition | |||
func (statement *Statement) ProcessIDParam() error { | |||
if statement.idParam == nil || statement.RefTable == nil { | |||
if statement.idParam == nil { | |||
return nil | |||
} | |||
if statement.RefTable == nil { | |||
return ErrIDConditionWithNoTable{statement.idParam} | |||
} | |||
if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) { | |||
fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam) | |||
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", | |||
len(statement.RefTable.PrimaryKeys), | |||
len(statement.idParam), | |||
@@ -797,8 +797,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, | |||
if !requiredField && fieldValue.Uint() == 0 { | |||
continue | |||
} | |||
t := int64(fieldValue.Uint()) | |||
val = reflect.ValueOf(&t).Interface() | |||
val = fieldValue.Interface() | |||
case reflect.Struct: | |||
if fieldType.ConvertibleTo(schemas.TimeType) { | |||
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) | |||
@@ -79,28 +79,6 @@ const insertSelectPlaceHolder = true | |||
func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { | |||
switch argv := arg.(type) { | |||
case bool: | |||
if statement.dialect.URI().DBType == schemas.MSSQL { | |||
if argv { | |||
if _, err := w.WriteString("1"); err != nil { | |||
return err | |||
} | |||
} else { | |||
if _, err := w.WriteString("0"); err != nil { | |||
return err | |||
} | |||
} | |||
} else { | |||
if argv { | |||
if _, err := w.WriteString("true"); err != nil { | |||
return err | |||
} | |||
} else { | |||
if _, err := w.WriteString("false"); err != nil { | |||
return err | |||
} | |||
} | |||
} | |||
case *builder.Builder: | |||
if _, err := w.WriteString("("); err != nil { | |||
return err | |||
@@ -116,7 +94,15 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er | |||
if err := w.WriteByte('?'); err != nil { | |||
return err | |||
} | |||
w.Append(arg) | |||
if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { | |||
if v { | |||
w.Append(1) | |||
} else { | |||
w.Append(0) | |||
} | |||
} else { | |||
w.Append(arg) | |||
} | |||
} else { | |||
var convertFunc = convertStringSingleQuote | |||
if statement.dialect.URI().DBType == schemas.MYSQL { | |||
@@ -190,8 +190,7 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, | |||
if !requiredField && fieldValue.Uint() == 0 { | |||
continue | |||
} | |||
t := int64(fieldValue.Uint()) | |||
val = reflect.ValueOf(&t).Interface() | |||
val = fieldValue.Interface() | |||
case reflect.Struct: | |||
if fieldType.ConvertibleTo(schemas.TimeType) { | |||
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) | |||
@@ -147,7 +147,7 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl | |||
} | |||
return nil, ErrUnSupportedType | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||
return int64(fieldValue.Uint()), nil | |||
return fieldValue.Uint(), nil | |||
default: | |||
return fieldValue.Interface(), nil | |||
} | |||
@@ -5,19 +5,15 @@ | |||
package log | |||
import ( | |||
"context" | |||
"time" | |||
"fmt" | |||
"xorm.io/xorm/contexts" | |||
) | |||
// LogContext represents a log context | |||
type LogContext struct { | |||
Ctx context.Context | |||
SQL string // log content or SQL | |||
Args []interface{} // if it's a SQL, it's the arguments | |||
ExecuteTime time.Duration | |||
Err error // SQL executed error | |||
} | |||
type LogContext contexts.ContextHook | |||
// SQLLogger represents an interface to log SQL | |||
type SQLLogger interface { | |||
BeforeSQL(context LogContext) // only invoked when IsShowSQL is true | |||
AfterSQL(context LogContext) // only invoked when IsShowSQL is true | |||
@@ -43,55 +39,77 @@ var ( | |||
_ ContextLogger = &LoggerAdapter{} | |||
) | |||
// LoggerAdapter wraps a Logger interafce as LoggerContext interface | |||
// enumerate all the context keys | |||
var ( | |||
SessionIDKey = "__xorm_session_id" | |||
SessionShowSQLKey = "__xorm_show_sql" | |||
) | |||
// LoggerAdapter wraps a Logger interface as LoggerContext interface | |||
type LoggerAdapter struct { | |||
logger Logger | |||
} | |||
// NewLoggerAdapter creates an adapter for old xorm logger interface | |||
func NewLoggerAdapter(logger Logger) ContextLogger { | |||
return &LoggerAdapter{ | |||
logger: logger, | |||
} | |||
} | |||
// BeforeSQL implements ContextLogger | |||
func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {} | |||
// AfterSQL implements ContextLogger | |||
func (l *LoggerAdapter) AfterSQL(ctx LogContext) { | |||
var sessionPart string | |||
v := ctx.Ctx.Value(SessionIDKey) | |||
if key, ok := v.(string); ok { | |||
sessionPart = fmt.Sprintf(" [%s]", key) | |||
} | |||
if ctx.ExecuteTime > 0 { | |||
l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime) | |||
l.logger.Infof("[SQL]%s %s %v - %v", sessionPart, ctx.SQL, ctx.Args, ctx.ExecuteTime) | |||
} else { | |||
l.logger.Infof("[SQL] %v %v", ctx.SQL, ctx.Args) | |||
l.logger.Infof("[SQL]%s %s %v", sessionPart, ctx.SQL, ctx.Args) | |||
} | |||
} | |||
// Debugf implements ContextLogger | |||
func (l *LoggerAdapter) Debugf(format string, v ...interface{}) { | |||
l.logger.Debugf(format, v...) | |||
} | |||
// Errorf implements ContextLogger | |||
func (l *LoggerAdapter) Errorf(format string, v ...interface{}) { | |||
l.logger.Errorf(format, v...) | |||
} | |||
// Infof implements ContextLogger | |||
func (l *LoggerAdapter) Infof(format string, v ...interface{}) { | |||
l.logger.Infof(format, v...) | |||
} | |||
// Warnf implements ContextLogger | |||
func (l *LoggerAdapter) Warnf(format string, v ...interface{}) { | |||
l.logger.Warnf(format, v...) | |||
} | |||
// Level implements ContextLogger | |||
func (l *LoggerAdapter) Level() LogLevel { | |||
return l.logger.Level() | |||
} | |||
// SetLevel implements ContextLogger | |||
func (l *LoggerAdapter) SetLevel(lv LogLevel) { | |||
l.logger.SetLevel(lv) | |||
} | |||
// ShowSQL implements ContextLogger | |||
func (l *LoggerAdapter) ShowSQL(show ...bool) { | |||
l.logger.ShowSQL(show...) | |||
} | |||
// IsShowSQL implements ContextLogger | |||
func (l *LoggerAdapter) IsShowSQL() bool { | |||
return l.logger.IsShowSQL() | |||
} |
@@ -7,6 +7,7 @@ package names | |||
import ( | |||
"strings" | |||
"sync" | |||
"unsafe" | |||
) | |||
// Mapper represents a name convertation between struct's fields name and table's column name | |||
@@ -77,19 +78,24 @@ func (m SameMapper) Table2Obj(t string) string { | |||
type SnakeMapper struct { | |||
} | |||
func b2s(b []byte) string { | |||
return *(*string)(unsafe.Pointer(&b)) | |||
} | |||
func snakeCasedName(name string) string { | |||
newstr := make([]rune, 0) | |||
for idx, chr := range name { | |||
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { | |||
if idx > 0 { | |||
newstr := make([]byte, 0, len(name)+1) | |||
for i := 0; i < len(name); i++ { | |||
c := name[i] | |||
if isUpper := 'A' <= c && c <= 'Z'; isUpper { | |||
if i > 0 { | |||
newstr = append(newstr, '_') | |||
} | |||
chr -= ('A' - 'a') | |||
c += 'a' - 'A' | |||
} | |||
newstr = append(newstr, chr) | |||
newstr = append(newstr, c) | |||
} | |||
return string(newstr) | |||
return b2s(newstr) | |||
} | |||
func (mapper SnakeMapper) Obj2Table(name string) string { | |||
@@ -97,27 +103,28 @@ func (mapper SnakeMapper) Obj2Table(name string) string { | |||
} | |||
func titleCasedName(name string) string { | |||
newstr := make([]rune, 0) | |||
newstr := make([]byte, 0, len(name)) | |||
upNextChar := true | |||
name = strings.ToLower(name) | |||
for _, chr := range name { | |||
for i := 0; i < len(name); i++ { | |||
c := name[i] | |||
switch { | |||
case upNextChar: | |||
upNextChar = false | |||
if 'a' <= chr && chr <= 'z' { | |||
chr -= ('a' - 'A') | |||
if 'a' <= c && c <= 'z' { | |||
c -= 'a' - 'A' | |||
} | |||
case chr == '_': | |||
case c == '_': | |||
upNextChar = true | |||
continue | |||
} | |||
newstr = append(newstr, chr) | |||
newstr = append(newstr, c) | |||
} | |||
return string(newstr) | |||
return b2s(newstr) | |||
} | |||
func (mapper SnakeMapper) Table2Obj(name string) string { | |||
@@ -76,3 +76,69 @@ func (session *Session) executeProcessors() error { | |||
} | |||
return nil | |||
} | |||
func cleanupProcessorsClosures(slices *[]func(interface{})) { | |||
if len(*slices) > 0 { | |||
*slices = make([]func(interface{}), 0) | |||
} | |||
} | |||
func executeBeforeClosures(session *Session, bean interface{}) { | |||
// handle before delete processors | |||
for _, closure := range session.beforeClosures { | |||
closure(bean) | |||
} | |||
cleanupProcessorsClosures(&session.beforeClosures) | |||
} | |||
func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) { | |||
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { | |||
for ii, key := range fields { | |||
b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) | |||
} | |||
} | |||
} | |||
func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) { | |||
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { | |||
for ii, key := range fields { | |||
b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) | |||
} | |||
} | |||
} | |||
func buildAfterProcessors(session *Session, bean interface{}) { | |||
// handle afterClosures | |||
for _, closure := range session.afterClosures { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
closure(bean) | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
if a, has := bean.(AfterLoadProcessor); has { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
a.AfterLoad() | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
if a, has := bean.(AfterLoadSessionProcessor); has { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
a.AfterLoad(sess) | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
} |
@@ -5,8 +5,10 @@ | |||
package schemas | |||
import ( | |||
"errors" | |||
"fmt" | |||
"reflect" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
@@ -115,3 +117,17 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { | |||
return &fieldValue, nil | |||
} | |||
// ConvertID converts id content to suitable type according column type | |||
func (col *Column) ConvertID(sid string) (interface{}, error) { | |||
if col.SQLType.IsNumeric() { | |||
n, err := strconv.ParseInt(sid, 10, 64) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return n, nil | |||
} else if col.SQLType.IsText() { | |||
return sid, nil | |||
} | |||
return nil, errors.New("not supported") | |||
} |
@@ -5,7 +5,9 @@ | |||
package schemas | |||
import ( | |||
"fmt" | |||
"reflect" | |||
"strconv" | |||
"strings" | |||
) | |||
@@ -28,6 +30,7 @@ type Table struct { | |||
Comment string | |||
} | |||
// NewEmptyTable creates an empty table | |||
func NewEmptyTable() *Table { | |||
return NewTable("", nil) | |||
} | |||
@@ -44,10 +47,12 @@ func NewTable(name string, t reflect.Type) *Table { | |||
} | |||
} | |||
// Columns returns table's columns | |||
func (table *Table) Columns() []*Column { | |||
return table.columns | |||
} | |||
// ColumnsSeq returns table's column names according sequence | |||
func (table *Table) ColumnsSeq() []string { | |||
return table.columnsSeq | |||
} | |||
@@ -61,6 +66,7 @@ func (table *Table) columnsByName(name string) []*Column { | |||
return nil | |||
} | |||
// GetColumn returns column according column name, if column not found, return nil | |||
func (table *Table) GetColumn(name string) *Column { | |||
cols := table.columnsByName(name) | |||
if cols != nil { | |||
@@ -70,6 +76,7 @@ func (table *Table) GetColumn(name string) *Column { | |||
return nil | |||
} | |||
// GetColumnIdx returns column according name and idx | |||
func (table *Table) GetColumnIdx(name string, idx int) *Column { | |||
cols := table.columnsByName(name) | |||
if cols != nil && idx < len(cols) { | |||
@@ -144,3 +151,45 @@ func (table *Table) AddColumn(col *Column) { | |||
func (table *Table) AddIndex(index *Index) { | |||
table.Indexes[index.Name] = index | |||
} | |||
// IDOfV get id from one value of struct | |||
func (table *Table) IDOfV(rv reflect.Value) (PK, error) { | |||
v := reflect.Indirect(rv) | |||
pk := make([]interface{}, len(table.PrimaryKeys)) | |||
for i, col := range table.PKColumns() { | |||
var err error | |||
fieldName := col.FieldName | |||
for { | |||
parts := strings.SplitN(fieldName, ".", 2) | |||
if len(parts) == 1 { | |||
break | |||
} | |||
v = v.FieldByName(parts[0]) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
if v.Kind() != reflect.Struct { | |||
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) | |||
} | |||
fieldName = parts[1] | |||
} | |||
pkField := v.FieldByName(fieldName) | |||
switch pkField.Kind() { | |||
case reflect.String: | |||
pk[i], err = col.ConvertID(pkField.String()) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
pk[i], err = col.ConvertID(strconv.FormatInt(pkField.Int(), 10)) | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
// id of uint will be converted to int64 | |||
pk[i], err = col.ConvertID(strconv.FormatUint(pkField.Uint(), 10)) | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
return PK(pk), nil | |||
} |
@@ -6,10 +6,14 @@ package xorm | |||
import ( | |||
"context" | |||
"crypto/rand" | |||
"crypto/sha256" | |||
"database/sql" | |||
"encoding/hex" | |||
"errors" | |||
"fmt" | |||
"hash/crc32" | |||
"io" | |||
"reflect" | |||
"strings" | |||
"time" | |||
@@ -19,6 +23,7 @@ import ( | |||
"xorm.io/xorm/core" | |||
"xorm.io/xorm/internal/json" | |||
"xorm.io/xorm/internal/statements" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/schemas" | |||
) | |||
@@ -42,24 +47,24 @@ func (e ErrFieldIsNotValid) Error() string { | |||
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) | |||
} | |||
type sessionType int | |||
type sessionType bool | |||
const ( | |||
engineSession sessionType = iota | |||
groupSession | |||
engineSession sessionType = false | |||
groupSession sessionType = true | |||
) | |||
// Session keep a pointer to sql.DB and provides all execution of all | |||
// kind of database operations. | |||
type Session struct { | |||
db *core.DB | |||
engine *Engine | |||
tx *core.Tx | |||
statement *statements.Statement | |||
isAutoCommit bool | |||
isCommitedOrRollbacked bool | |||
isAutoClose bool | |||
isClosed bool | |||
prepareStmt bool | |||
// Automatically reset the statement after operations that execute a SQL | |||
// query such as Count(), Find(), Get(), ... | |||
autoResetStatement bool | |||
@@ -70,81 +75,101 @@ type Session struct { | |||
afterDeleteBeans map[interface{}]*[]func(interface{}) | |||
// -- | |||
beforeClosures []func(interface{}) | |||
afterClosures []func(interface{}) | |||
beforeClosures []func(interface{}) | |||
afterClosures []func(interface{}) | |||
afterProcessors []executedProcessor | |||
prepareStmt bool | |||
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) | |||
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) | |||
lastSQL string | |||
lastSQLArgs []interface{} | |||
showSQL bool | |||
ctx context.Context | |||
sessionType sessionType | |||
} | |||
// Clone copy all the session's content and return a new session | |||
func (session *Session) Clone() *Session { | |||
var sess = *session | |||
return &sess | |||
func newSessionID() string { | |||
hash := sha256.New() | |||
_, err := io.CopyN(hash, rand.Reader, 50) | |||
if err != nil { | |||
return "????????????????????" | |||
} | |||
md := hash.Sum(nil) | |||
mdStr := hex.EncodeToString(md) | |||
return mdStr[0:20] | |||
} | |||
// Init reset the session as the init status. | |||
func (session *Session) Init() { | |||
session.statement = statements.NewStatement( | |||
session.engine.dialect, | |||
session.engine.tagParser, | |||
session.engine.DatabaseTZ, | |||
) | |||
session.db = session.engine.db | |||
session.isAutoCommit = true | |||
session.isCommitedOrRollbacked = false | |||
session.isAutoClose = false | |||
session.autoResetStatement = true | |||
session.prepareStmt = false | |||
// !nashtsai! is lazy init better? | |||
session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) | |||
session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) | |||
session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) | |||
session.beforeClosures = make([]func(interface{}), 0) | |||
session.afterClosures = make([]func(interface{}), 0) | |||
session.stmtCache = make(map[uint32]*core.Stmt) | |||
session.afterProcessors = make([]executedProcessor, 0) | |||
session.lastSQL = "" | |||
session.lastSQLArgs = []interface{}{} | |||
func newSession(engine *Engine) *Session { | |||
var ctx context.Context | |||
if engine.logSessionID { | |||
ctx = context.WithValue(engine.defaultContext, log.SessionIDKey, newSessionID()) | |||
} else { | |||
ctx = engine.defaultContext | |||
} | |||
session.ctx = session.engine.defaultContext | |||
return &Session{ | |||
ctx: ctx, | |||
engine: engine, | |||
tx: nil, | |||
statement: statements.NewStatement( | |||
engine.dialect, | |||
engine.tagParser, | |||
engine.DatabaseTZ, | |||
), | |||
isClosed: false, | |||
isAutoCommit: true, | |||
isCommitedOrRollbacked: false, | |||
isAutoClose: false, | |||
autoResetStatement: true, | |||
prepareStmt: false, | |||
afterInsertBeans: make(map[interface{}]*[]func(interface{}), 0), | |||
afterUpdateBeans: make(map[interface{}]*[]func(interface{}), 0), | |||
afterDeleteBeans: make(map[interface{}]*[]func(interface{}), 0), | |||
beforeClosures: make([]func(interface{}), 0), | |||
afterClosures: make([]func(interface{}), 0), | |||
afterProcessors: make([]executedProcessor, 0), | |||
stmtCache: make(map[uint32]*core.Stmt), | |||
lastSQL: "", | |||
lastSQLArgs: make([]interface{}, 0), | |||
sessionType: engineSession, | |||
} | |||
} | |||
// Close release the connection from pool | |||
func (session *Session) Close() { | |||
func (session *Session) Close() error { | |||
for _, v := range session.stmtCache { | |||
v.Close() | |||
if err := v.Close(); err != nil { | |||
return err | |||
} | |||
} | |||
if session.db != nil { | |||
if !session.isClosed { | |||
// When Close be called, if session is a transaction and do not call | |||
// Commit or Rollback, then call Rollback. | |||
if session.tx != nil && !session.isCommitedOrRollbacked { | |||
session.Rollback() | |||
if err := session.Rollback(); err != nil { | |||
return err | |||
} | |||
} | |||
session.tx = nil | |||
session.stmtCache = nil | |||
session.db = nil | |||
session.isClosed = true | |||
} | |||
return nil | |||
} | |||
func (session *Session) db() *core.DB { | |||
return session.engine.db | |||
} | |||
func (session *Session) getQueryer() core.Queryer { | |||
if session.tx != nil { | |||
return session.tx | |||
} | |||
return session.db | |||
return session.db() | |||
} | |||
// ContextCache enable context cache or not | |||
@@ -155,7 +180,7 @@ func (session *Session) ContextCache(context contexts.ContextCache) *Session { | |||
// IsClosed returns if session is closed | |||
func (session *Session) IsClosed() bool { | |||
return session.db == nil | |||
return session.isClosed | |||
} | |||
func (session *Session) resetStatement() { | |||
@@ -264,12 +289,12 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { | |||
} | |||
// MustLogSQL means record SQL or not and don't follow engine's setting | |||
func (session *Session) MustLogSQL(log ...bool) *Session { | |||
func (session *Session) MustLogSQL(logs ...bool) *Session { | |||
var showSQL = true | |||
if len(log) > 0 { | |||
showSQL = log[0] | |||
if len(logs) > 0 { | |||
showSQL = logs[0] | |||
} | |||
session.ctx = context.WithValue(session.ctx, "__xorm_show_sql", showSQL) | |||
session.ctx = context.WithValue(session.ctx, log.SessionShowSQLKey, showSQL) | |||
return session | |||
} | |||
@@ -300,17 +325,7 @@ func (session *Session) Having(conditions string) *Session { | |||
// DB db return the wrapper of sql.DB | |||
func (session *Session) DB() *core.DB { | |||
if session.db == nil { | |||
session.db = session.engine.DB() | |||
session.stmtCache = make(map[uint32]*core.Stmt, 0) | |||
} | |||
return session.db | |||
} | |||
func cleanupProcessorsClosures(slices *[]func(interface{})) { | |||
if len(*slices) > 0 { | |||
*slices = make([]func(interface{}), 0) | |||
} | |||
return session.db() | |||
} | |||
func (session *Session) canCache() bool { | |||
@@ -404,56 +419,17 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa | |||
return nil, err | |||
} | |||
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { | |||
for ii, key := range fields { | |||
b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) | |||
} | |||
} | |||
executeBeforeSet(bean, fields, scanResults) | |||
return scanResults, nil | |||
} | |||
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { | |||
defer func() { | |||
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { | |||
for ii, key := range fields { | |||
b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) | |||
} | |||
} | |||
executeAfterSet(bean, fields, scanResults) | |||
}() | |||
// handle afterClosures | |||
for _, closure := range session.afterClosures { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
closure(bean) | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
if a, has := bean.(AfterLoadProcessor); has { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
a.AfterLoad() | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
if a, has := bean.(AfterLoadSessionProcessor); has { | |||
session.afterProcessors = append(session.afterProcessors, executedProcessor{ | |||
fun: func(sess *Session, bean interface{}) error { | |||
a.AfterLoad(sess) | |||
return nil | |||
}, | |||
session: session, | |||
bean: bean, | |||
}) | |||
} | |||
buildAfterProcessors(session, bean) | |||
var tempMap = make(map[string]int) | |||
var pk schemas.PK | |||
@@ -911,7 +887,7 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { | |||
} | |||
} | |||
// Context sets the context on this session | |||
// ContextHook sets the context on this session | |||
func (session *Session) Context(ctx context.Context) *Session { | |||
session.ctx = ctx | |||
return session | |||
@@ -96,11 +96,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { | |||
return 0, err | |||
} | |||
// handle before delete processors | |||
for _, closure := range session.beforeClosures { | |||
closure(bean) | |||
} | |||
cleanupProcessorsClosures(&session.beforeClosures) | |||
executeBeforeClosures(session, bean) | |||
if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { | |||
processor.BeforeDelete() | |||
@@ -60,6 +60,12 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte | |||
if session.statement.OrderStr != "" { | |||
session.statement.OrderStr = "" | |||
} | |||
if session.statement.LimitN != nil { | |||
session.statement.LimitN = nil | |||
} | |||
if session.statement.Start > 0 { | |||
session.statement.Start = 0 | |||
} | |||
// session has stored the conditions so we use `unscoped` to avoid duplicated condition. | |||
return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) | |||
@@ -108,8 +114,11 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
) | |||
if tp == tpStruct { | |||
if !session.statement.NoAutoCondition && len(condiBean) > 0 { | |||
var err error | |||
autoCond, err = session.statement.BuildConds(table, condiBean[0], true, true, false, true, addedTableName) | |||
condTable, err := session.engine.tagParser.Parse(reflect.ValueOf(condiBean[0])) | |||
if err != nil { | |||
return err | |||
} | |||
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, addedTableName) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -317,7 +326,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
} | |||
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) | |||
for i, col := range table.PKColumns() { | |||
pk[i], err = session.engine.idTypeAssertion(col, res[i]) | |||
pk[i], err = col.ConvertID(res[i]) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -367,7 +376,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
} else { | |||
session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean) | |||
pk, err := session.engine.IDOf(bean) | |||
pk, err := table.IDOfV(reflect.ValueOf(bean)) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -416,7 +425,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
if err != nil { | |||
return err | |||
} | |||
session.statement = statement | |||
vs := reflect.Indirect(reflect.ValueOf(beans)) | |||
@@ -425,7 +433,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
if rv.Kind() != reflect.Ptr { | |||
rv = rv.Addr() | |||
} | |||
id, err := session.engine.idOfV(rv) | |||
id, err := table.IDOfV(rv) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -242,7 +242,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, | |||
if err != nil { | |||
return false, err | |||
} | |||
// close it before covert data | |||
// close it before convert data | |||
rows.Close() | |||
dataStruct := utils.ReflectValue(bean) | |||
@@ -12,7 +12,6 @@ import ( | |||
"strconv" | |||
"strings" | |||
"xorm.io/builder" | |||
"xorm.io/xorm/internal/utils" | |||
"xorm.io/xorm/schemas" | |||
) | |||
@@ -112,13 +111,14 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
return 0, ErrTableNotFound | |||
} | |||
table := session.statement.RefTable | |||
size := sliceValue.Len() | |||
var colNames []string | |||
var colMultiPlaces []string | |||
var args []interface{} | |||
var cols []*schemas.Column | |||
var ( | |||
table = session.statement.RefTable | |||
size = sliceValue.Len() | |||
colNames []string | |||
colMultiPlaces []string | |||
args []interface{} | |||
cols []*schemas.Column | |||
) | |||
for i := 0; i < size; i++ { | |||
v := sliceValue.Index(i) | |||
@@ -233,7 +233,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
for _, closure := range session.afterClosures { | |||
closure(elemValue) | |||
} | |||
if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok { | |||
if processor, ok := elemValue.(AfterInsertProcessor); ok { | |||
processor.AfterInsert() | |||
} | |||
} else { | |||
@@ -246,7 +246,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
session.afterInsertBeans[elemValue] = &afterClosures | |||
} | |||
} else { | |||
if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { | |||
if _, ok := elemValue.(AfterInsertProcessor); ok { | |||
session.afterInsertBeans[elemValue] = nil | |||
} | |||
} | |||
@@ -265,12 +265,11 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { | |||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | |||
if sliceValue.Kind() != reflect.Slice { | |||
return 0, ErrParamsType | |||
return 0, ErrPtrSliceType | |||
} | |||
if sliceValue.Len() <= 0 { | |||
return 0, nil | |||
return 0, ErrNoElementsOnSlice | |||
} | |||
return session.innerInsertMulti(rowsSlicePtr) | |||
@@ -483,7 +482,7 @@ func (session *Session) cacheInsert(table string) error { | |||
if cacher == nil { | |||
return nil | |||
} | |||
session.engine.logger.Debugf("[cache] clear sql: %v", table) | |||
session.engine.logger.Debugf("[cache] clear SQL: %v", table) | |||
cacher.ClearIds(table) | |||
return nil | |||
} | |||
@@ -623,74 +622,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, | |||
return 0, ErrTableNotFound | |||
} | |||
exprs := session.statement.ExprColumns | |||
w := builder.NewWriter() | |||
// if insert where | |||
if session.statement.Conds().IsValid() { | |||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { | |||
return 0, err | |||
} | |||
if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { | |||
return 0, err | |||
} | |||
if _, err := w.WriteString(") SELECT "); err != nil { | |||
return 0, err | |||
} | |||
if err := session.statement.WriteArgs(w, args); err != nil { | |||
return 0, err | |||
} | |||
if len(exprs.Args) > 0 { | |||
if _, err := w.WriteString(","); err != nil { | |||
return 0, err | |||
} | |||
if err := exprs.WriteArgs(w); err != nil { | |||
return 0, err | |||
} | |||
} | |||
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { | |||
return 0, err | |||
} | |||
if err := session.statement.Conds().WriteTo(w); err != nil { | |||
return 0, err | |||
} | |||
} else { | |||
qm := strings.Repeat("?,", len(columns)) | |||
qm = qm[:len(qm)-1] | |||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { | |||
return 0, err | |||
} | |||
if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { | |||
return 0, err | |||
} | |||
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil { | |||
return 0, err | |||
} | |||
w.Append(args...) | |||
if len(exprs.Args) > 0 { | |||
if _, err := w.WriteString(","); err != nil { | |||
return 0, err | |||
} | |||
if err := exprs.WriteArgs(w); err != nil { | |||
return 0, err | |||
} | |||
} | |||
if _, err := w.WriteString(")"); err != nil { | |||
return 0, err | |||
} | |||
sql, args, err := session.statement.GenInsertMapSQL(columns, args) | |||
if err != nil { | |||
return 0, err | |||
} | |||
sql := w.String() | |||
args = w.Args() | |||
if err := session.cacheInsert(tableName); err != nil { | |||
return 0, err | |||
} | |||
@@ -115,6 +115,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { | |||
t := v.Type() | |||
if t.Kind() == reflect.Ptr { | |||
t = t.Elem() | |||
v = v.Elem() | |||
} | |||
if t.Kind() != reflect.Struct { | |||
return nil, ErrUnsupportedType | |||
@@ -1,81 +0,0 @@ | |||
// Copyright 2015 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build go1.11 | |||
package xorm | |||
import ( | |||
"context" | |||
"os" | |||
"runtime" | |||
"time" | |||
"xorm.io/xorm/caches" | |||
"xorm.io/xorm/core" | |||
"xorm.io/xorm/dialects" | |||
"xorm.io/xorm/log" | |||
"xorm.io/xorm/names" | |||
"xorm.io/xorm/schemas" | |||
"xorm.io/xorm/tags" | |||
) | |||
func close(engine *Engine) { | |||
engine.Close() | |||
} | |||
// NewEngine new a db manager according to the parameter. Currently support four | |||
// drivers | |||
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||
dialect, err := dialects.OpenDialect(driverName, dataSourceName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
db, err := core.Open(driverName, dataSourceName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
cacherMgr := caches.NewManager() | |||
mapper := names.NewCacheMapper(new(names.SnakeMapper)) | |||
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) | |||
engine := &Engine{ | |||
dialect: dialect, | |||
TZLocation: time.Local, | |||
defaultContext: context.Background(), | |||
cacherMgr: cacherMgr, | |||
tagParser: tagParser, | |||
driverName: driverName, | |||
dataSourceName: dataSourceName, | |||
db: db, | |||
} | |||
if dialect.URI().DBType == schemas.SQLITE { | |||
engine.DatabaseTZ = time.UTC | |||
} else { | |||
engine.DatabaseTZ = time.Local | |||
} | |||
logger := log.NewSimpleLogger(os.Stdout) | |||
logger.SetLevel(log.LOG_INFO) | |||
engine.SetLogger(log.NewLoggerAdapter(logger)) | |||
runtime.SetFinalizer(engine, close) | |||
return engine, nil | |||
} | |||
// NewEngineWithParams new a db manager with params. The params will be passed to dialects. | |||
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { | |||
engine, err := NewEngine(driverName, dataSourceName) | |||
engine.dialect.SetParams(params) | |||
return engine, err | |||
} | |||
// Clone clone an engine | |||
func (engine *Engine) Clone() (*Engine, error) { | |||
return NewEngine(engine.DriverName(), engine.DataSourceName()) | |||
} |