* feat: add unit test workflow * feat:the ability to automatically run unit tests after creating a pull request. * feat:the ability to automatically run unit tests after creating a pull request. * feat:the ability to automatically run unit tests after creating a pull request. * feat:the ability to automatically run unit tests after creating a pull request. * feat:the ability to automatically run unit tests after creating a pull request. * feat:the ability to automatically run unit tests after creating a pull request. * Optimize/at build lock key performance (#837) * Refer to buildlockkey2 optimization #829 * Time complexity O(NM)-> O(NK) about buildlockkey and buildlockkey2 Increased readability #829 * update import sort #829 * update Encapsulation into util packages #829 * Support Update join (#761) * duplicate image row for update join * update join condition placeholder param error * update join bugfix * Open test annotations * recover update executor * recover update test * recover update test * modified version param --------- Co-authored-by: JayLiu <38887641+luky116@users.noreply.github.com> Co-authored-by: FengZhang <zfcode@qq.com> --------- Co-authored-by: jimin <slievrly@163.com> Co-authored-by: JayLiu <38887641+luky116@users.noreply.github.com> Co-authored-by: FengZhang <zfcode@qq.com> Co-authored-by: Wiggins <125641755+MinatoWu@users.noreply.github.com> Co-authored-by: lxfeng1997 <33981743+lxfeng1997@users.noreply.github.com>develop-tmp
@@ -17,7 +17,7 @@ | |||||
# This is a workflow to help you test the unit case and show codecov | # This is a workflow to help you test the unit case and show codecov | ||||
name: "build and codecov" | |||||
name: "build" | |||||
on: | on: | ||||
push: | push: | ||||
@@ -60,9 +60,3 @@ jobs: | |||||
- name: "run go build" | - name: "run go build" | ||||
run: go build -v ./... | run: go build -v ./... | ||||
- name: "run go test and out codecov" | |||||
run: go test -v ./... -race -coverprofile=coverage.out -covermode=atomic | |||||
- name: "upload coverage" | |||||
uses: codecov/codecov-action@v3 |
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Runs the unit-test suite (with race detection and coverage) on pushes to
# master and on every pull request.
name: "Unit Test"

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ "*" ]
    types: [opened, synchronize, reopened]

permissions:
  contents: read

jobs:
  unit-test:
    name: Unit Test
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        golang:
          - 1.18
    steps:
      - name: "Set up Go"
        uses: actions/setup-go@v3
        with:
          go-version: ${{ matrix.golang }}
      - name: "Checkout code"
        uses: actions/checkout@v3
        with:
          submodules: true
      - name: "Cache dependencies"
        uses: actions/cache@v3
        with:
          path: ~/go/pkg/mod
          key: "${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}"
          restore-keys: |
            "${{ runner.os }}-go-"
      # The ubuntu runner image ships with a MySQL service already bound to
      # port 3306; stop it so test fixtures can use the default port.
      - name: Shutdown default mysql
        run: sudo service mysql stop
      # `run` steps execute under `bash -e`, so a failing `go test` fails the
      # step immediately; checking `$?` afterwards (as the previous version
      # did) was dead code and has been removed.
      - name: "Run Unit Tests"
        run: |
          echo "=== Starting Unit Tests ==="
          go test -v ./... -race -coverprofile=coverage.txt -covermode=atomic -timeout 10m
          echo "✅ Unit tests completed successfully"
      # `if: always()` keeps the coverage artifact even when the test step
      # fails — which is exactly when it is most useful for debugging.
      - name: "Archive test results"
        if: always()
        uses: actions/upload-artifact@v3
        with:
          name: test-results
          path: coverage.txt
          retention-days: 7
@@ -18,20 +18,21 @@ | |||||
package at | package at | ||||
import ( | import ( | ||||
"bytes" | |||||
"context" | "context" | ||||
"database/sql" | "database/sql" | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"fmt" | "fmt" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/undo" | |||||
"strings" | "strings" | ||||
"github.com/arana-db/parser/ast" | "github.com/arana-db/parser/ast" | ||||
"github.com/arana-db/parser/model" | |||||
"github.com/arana-db/parser/test_driver" | "github.com/arana-db/parser/test_driver" | ||||
gxsort "github.com/dubbogo/gost/sort" | gxsort "github.com/dubbogo/gost/sort" | ||||
"github.com/pkg/errors" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/exec" | "seata.apache.org/seata-go/pkg/datasource/sql/exec" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | "seata.apache.org/seata-go/pkg/datasource/sql/types" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/undo" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | "seata.apache.org/seata-go/pkg/datasource/sql/util" | ||||
"seata.apache.org/seata-go/pkg/util/reflectx" | "seata.apache.org/seata-go/pkg/util/reflectx" | ||||
) | ) | ||||
@@ -98,7 +99,13 @@ func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.Named | |||||
selectArgs = make([]driver.NamedValue, 0) | selectArgs = make([]driver.NamedValue, 0) | ||||
) | ) | ||||
b.traversalArgs(stmt.From.TableRefs, &selectArgsIndexs) | |||||
b.traversalArgs(stmt.Where, &selectArgsIndexs) | b.traversalArgs(stmt.Where, &selectArgsIndexs) | ||||
if stmt.GroupBy != nil { | |||||
for _, item := range stmt.GroupBy.Items { | |||||
b.traversalArgs(item, &selectArgsIndexs) | |||||
} | |||||
} | |||||
if stmt.OrderBy != nil { | if stmt.OrderBy != nil { | ||||
for _, item := range stmt.OrderBy.Items { | for _, item := range stmt.OrderBy.Items { | ||||
b.traversalArgs(item, &selectArgsIndexs) | b.traversalArgs(item, &selectArgsIndexs) | ||||
@@ -143,6 +150,16 @@ func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) { | |||||
b.traversalArgs(exprs[i], argsIndex) | b.traversalArgs(exprs[i], argsIndex) | ||||
} | } | ||||
break | break | ||||
case *ast.Join: | |||||
exprs := node.(*ast.Join) | |||||
b.traversalArgs(exprs.Left, argsIndex) | |||||
if exprs.Right != nil { | |||||
b.traversalArgs(exprs.Right, argsIndex) | |||||
} | |||||
if exprs.On != nil { | |||||
b.traversalArgs(exprs.On.Expr, argsIndex) | |||||
} | |||||
break | |||||
case *test_driver.ParamMarkerExpr: | case *test_driver.ParamMarkerExpr: | ||||
*argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) | *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) | ||||
break | break | ||||
@@ -230,6 +247,64 @@ func (b *baseExecutor) containsPKByName(meta *types.TableMeta, columns []string) | |||||
return matchCounter == len(pkColumnNameList) | return matchCounter == len(pkColumnNameList) | ||||
} | } | ||||
func (u *baseExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, inUseFields []*ast.Assignment) ([]*ast.SelectField, error) { | |||||
fields := make([]*ast.SelectField, 0, len(inUseFields)) | |||||
tableName := tableAliases | |||||
if tableAliases == "" { | |||||
tableName = tableMeta.TableName | |||||
} | |||||
if undo.UndoConfig.OnlyCareUpdateColumns { | |||||
for _, column := range inUseFields { | |||||
tn := column.Column.Table.O | |||||
if tn != "" && tn != tableName { | |||||
continue | |||||
} | |||||
fields = append(fields, &ast.SelectField{ | |||||
Expr: &ast.ColumnNameExpr{ | |||||
Name: column.Column, | |||||
}, | |||||
}) | |||||
} | |||||
if len(fields) == 0 { | |||||
return fields, nil | |||||
} | |||||
// select indexes columns | |||||
for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() { | |||||
fields = append(fields, &ast.SelectField{ | |||||
Expr: &ast.ColumnNameExpr{ | |||||
Name: &ast.ColumnName{ | |||||
Table: model.CIStr{ | |||||
O: tableName, | |||||
L: tableName, | |||||
}, | |||||
Name: model.CIStr{ | |||||
O: columnName, | |||||
L: columnName, | |||||
}, | |||||
}, | |||||
}, | |||||
}) | |||||
} | |||||
} else { | |||||
fields = append(fields, &ast.SelectField{ | |||||
Expr: &ast.ColumnNameExpr{ | |||||
Name: &ast.ColumnName{ | |||||
Name: model.CIStr{ | |||||
O: "*", | |||||
L: "*", | |||||
}, | |||||
}, | |||||
}, | |||||
}) | |||||
} | |||||
return fields, nil | |||||
} | |||||
func getSqlNullValue(value interface{}) interface{} { | func getSqlNullValue(value interface{}) interface{} { | ||||
if value == nil { | if value == nil { | ||||
return nil | return nil | ||||
@@ -359,37 +434,25 @@ func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) | |||||
// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" | // the string as local key. the local key example(multi pk): "t_user:1_a,2_b" | ||||
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string { | func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string { | ||||
var ( | |||||
lockKeys bytes.Buffer | |||||
filedSequence int | |||||
) | |||||
lockKeys.WriteString(meta.TableName) | |||||
lockKeys.WriteString(":") | |||||
return util.BuildLockKey(records, meta) | |||||
} | |||||
keys := meta.GetPrimaryKeyOnlyName() | |||||
func (b *baseExecutor) rowsPrepare(ctx context.Context, conn driver.Conn, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) { | |||||
var queryer driver.Queryer | |||||
for _, row := range records.Rows { | |||||
if filedSequence > 0 { | |||||
lockKeys.WriteString(",") | |||||
} | |||||
pkSplitIndex := 0 | |||||
for _, column := range row.Columns { | |||||
var hasKeyColumn bool | |||||
for _, key := range keys { | |||||
if column.ColumnName == key { | |||||
hasKeyColumn = true | |||||
if pkSplitIndex > 0 { | |||||
lockKeys.WriteString("_") | |||||
} | |||||
lockKeys.WriteString(fmt.Sprintf("%v", column.Value)) | |||||
pkSplitIndex++ | |||||
} | |||||
} | |||||
if hasKeyColumn { | |||||
filedSequence++ | |||||
} | |||||
} | |||||
queryerContext, ok := conn.(driver.QueryerContext) | |||||
if !ok { | |||||
queryer, ok = conn.(driver.Queryer) | |||||
} | } | ||||
if ok { | |||||
var err error | |||||
rows, err = util.CtxDriverQuery(ctx, queryerContext, queryer, selectSQL, selectArgs) | |||||
return lockKeys.String() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} else { | |||||
return nil, errors.New("target conn should been driver.QueryerContext or driver.Queryer") | |||||
} | |||||
return rows, nil | |||||
} | } |
@@ -0,0 +1,212 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
package at | |||||
import ( | |||||
"github.com/stretchr/testify/assert" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | |||||
"testing" | |||||
) | |||||
func TestBaseExecBuildLockKey(t *testing.T) { | |||||
var exec baseExecutor | |||||
columnID := types.ColumnMeta{ | |||||
ColumnName: "id", | |||||
} | |||||
columnUserId := types.ColumnMeta{ | |||||
ColumnName: "userId", | |||||
} | |||||
columnName := types.ColumnMeta{ | |||||
ColumnName: "name", | |||||
} | |||||
columnAge := types.ColumnMeta{ | |||||
ColumnName: "age", | |||||
} | |||||
columnNonExistent := types.ColumnMeta{ | |||||
ColumnName: "non_existent", | |||||
} | |||||
columnsTwoPk := []types.ColumnMeta{columnID, columnUserId} | |||||
columnsThreePk := []types.ColumnMeta{columnID, columnUserId, columnAge} | |||||
columnsMixPk := []types.ColumnMeta{columnName, columnAge} | |||||
getColumnImage := func(columnName string, value interface{}) types.ColumnImage { | |||||
return types.ColumnImage{KeyType: types.IndexTypePrimaryKey, ColumnName: columnName, Value: value} | |||||
} | |||||
tests := []struct { | |||||
name string | |||||
metaData types.TableMeta | |||||
records types.RecordImage | |||||
expected string | |||||
}{ | |||||
{ | |||||
"Two Primary Keys", | |||||
types.TableMeta{ | |||||
TableName: "test_name", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsTwoPk}, | |||||
}, | |||||
}, | |||||
types.RecordImage{ | |||||
TableName: "test_name", | |||||
Rows: []types.RowImage{ | |||||
{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "user1")}}, | |||||
{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "user2")}}, | |||||
}, | |||||
}, | |||||
"test_name:1_user1,2_user2", | |||||
}, | |||||
{ | |||||
"Three Primary Keys", | |||||
types.TableMeta{ | |||||
TableName: "test2_name", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsThreePk}, | |||||
}, | |||||
}, | |||||
types.RecordImage{ | |||||
TableName: "test2_name", | |||||
Rows: []types.RowImage{ | |||||
{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "one"), getColumnImage("age", "11")}}, | |||||
{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "two"), getColumnImage("age", "22")}}, | |||||
{[]types.ColumnImage{getColumnImage("id", 3), getColumnImage("userId", "three"), getColumnImage("age", "33")}}, | |||||
}, | |||||
}, | |||||
"test2_name:1_one_11,2_two_22,3_three_33", | |||||
}, | |||||
{ | |||||
name: "Single Primary Key", | |||||
metaData: types.TableMeta{ | |||||
TableName: "single_key", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "single_key", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("id", 100)}}, | |||||
}, | |||||
}, | |||||
expected: "single_key:100", | |||||
}, | |||||
{ | |||||
name: "Mixed Type Keys", | |||||
metaData: types.TableMeta{ | |||||
TableName: "mixed_key", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsMixPk}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "mixed_key", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("name", "mike"), getColumnImage("age", 25)}}, | |||||
}, | |||||
}, | |||||
expected: "mixed_key:mike_25", | |||||
}, | |||||
{ | |||||
name: "Empty Records", | |||||
metaData: types.TableMeta{ | |||||
TableName: "empty", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{TableName: "empty"}, | |||||
expected: "empty:", | |||||
}, | |||||
{ | |||||
name: "Special Characters", | |||||
metaData: types.TableMeta{ | |||||
TableName: "special", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "special", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("id", "A,b_c")}}, | |||||
}, | |||||
}, | |||||
expected: "special:A,b_c", | |||||
}, | |||||
{ | |||||
name: "Non-existent Key Name", | |||||
metaData: types.TableMeta{ | |||||
TableName: "error_key", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnNonExistent}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "error_key", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("id", 1)}}, | |||||
}, | |||||
}, | |||||
expected: "error_key:", | |||||
}, | |||||
{ | |||||
name: "Multiple Rows With Nil PK Value", | |||||
metaData: types.TableMeta{ | |||||
TableName: "nil_pk", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "nil_pk", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("id", nil)}}, | |||||
{Columns: []types.ColumnImage{getColumnImage("id", 123)}}, | |||||
{Columns: []types.ColumnImage{getColumnImage("id", nil)}}, | |||||
}, | |||||
}, | |||||
expected: "nil_pk:,123,", | |||||
}, | |||||
{ | |||||
name: "PK As Bool And Float", | |||||
metaData: types.TableMeta{ | |||||
TableName: "type_pk", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnName, columnAge}}, | |||||
}, | |||||
}, | |||||
records: types.RecordImage{ | |||||
TableName: "type_pk", | |||||
Rows: []types.RowImage{ | |||||
{Columns: []types.ColumnImage{getColumnImage("name", true), getColumnImage("age", 3.14)}}, | |||||
{Columns: []types.ColumnImage{getColumnImage("name", false), getColumnImage("age", 0.0)}}, | |||||
}, | |||||
}, | |||||
expected: "type_pk:true_3.14,false_0", | |||||
}, | |||||
} | |||||
for _, tt := range tests { | |||||
t.Run(tt.name, func(t *testing.T) { | |||||
lockKeys := exec.buildLockKey(&tt.records, tt.metaData) | |||||
assert.Equal(t, tt.expected, lockKeys) | |||||
}) | |||||
} | |||||
} |
@@ -21,17 +21,17 @@ import ( | |||||
"context" | "context" | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"fmt" | "fmt" | ||||
"github.com/arana-db/parser/model" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | |||||
"strings" | "strings" | ||||
"github.com/arana-db/parser/ast" | "github.com/arana-db/parser/ast" | ||||
"github.com/arana-db/parser/format" | "github.com/arana-db/parser/format" | ||||
"github.com/arana-db/parser/model" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/datasource" | "seata.apache.org/seata-go/pkg/datasource/sql/datasource" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/exec" | "seata.apache.org/seata-go/pkg/datasource/sql/exec" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | "seata.apache.org/seata-go/pkg/datasource/sql/types" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/undo" | "seata.apache.org/seata-go/pkg/datasource/sql/undo" | ||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | |||||
"seata.apache.org/seata-go/pkg/util/bytes" | "seata.apache.org/seata-go/pkg/util/bytes" | ||||
"seata.apache.org/seata-go/pkg/util/log" | "seata.apache.org/seata-go/pkg/util/log" | ||||
) | ) | ||||
@@ -49,6 +49,10 @@ type updateExecutor struct { | |||||
// NewUpdateExecutor get update executor | // NewUpdateExecutor get update executor | ||||
func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { | func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { | ||||
// Because update join cannot be clearly identified when SQL cannot be parsed | |||||
if parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil { | |||||
return NewUpdateJoinExecutor(parserCtx, execContent, hooks) | |||||
} | |||||
return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}} | return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}} | ||||
} | } | ||||
@@ -0,0 +1,348 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
package at | |||||
import ( | |||||
"context" | |||||
"database/sql/driver" | |||||
"errors" | |||||
"io" | |||||
"reflect" | |||||
"strings" | |||||
"github.com/arana-db/parser/ast" | |||||
"github.com/arana-db/parser/format" | |||||
"github.com/arana-db/parser/model" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/datasource" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/exec" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | |||||
"seata.apache.org/seata-go/pkg/util/bytes" | |||||
"seata.apache.org/seata-go/pkg/util/log" | |||||
) | |||||
const (
	// LowerSupportGroupByPksVersion is the MySQL version boundary used by
	// buildGroupByClause: below it, grouping by the primary keys alone may
	// be rejected depending on @@SQL_MODE (see buildGroupByClause).
	LowerSupportGroupByPksVersion = "5.7.5"
)

// updateJoinExecutor execute update SQL for UPDATE ... JOIN statements,
// building one before/after record image per joined table.
type updateJoinExecutor struct {
	baseExecutor
	parserCtx   *types.ParseContext
	execContext *types.ExecContext
	// isLowerSupportGroupByPksVersion is true when the connected DB version
	// is below LowerSupportGroupByPksVersion (set in NewUpdateJoinExecutor).
	isLowerSupportGroupByPksVersion bool
	// sqlMode caches the @@SQL_MODE value, lazily fetched once in
	// buildGroupByClause.
	sqlMode string
	// tableAliasesMap maps table name -> alias as parsed from the join
	// clause (empty alias when the table has none).
	tableAliasesMap map[string]string
}
// NewUpdateJoinExecutor get executor | |||||
func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { | |||||
minimumVersion, _ := util.ConvertDbVersion(LowerSupportGroupByPksVersion) | |||||
currentVersion, _ := util.ConvertDbVersion(execContent.DbVersion) | |||||
return &updateJoinExecutor{ | |||||
parserCtx: parserCtx, | |||||
execContext: execContent, | |||||
baseExecutor: baseExecutor{hooks: hooks}, | |||||
isLowerSupportGroupByPksVersion: currentVersion < minimumVersion, | |||||
tableAliasesMap: make(map[string]string, 0), | |||||
} | |||||
} | |||||
// ExecContext exec SQL, and generate before image and after image.
// Sequence matters: before image -> run the UPDATE via f -> after image;
// both image sets are then appended to the transaction's round images for
// undo-log generation.
func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
	u.beforeHooks(ctx, u.execContext)
	defer func() {
		u.afterHooks(ctx, u.execContext)
	}()
	// Resolve table name -> alias pairs once; the before/after image
	// builders look tables up in this map.
	if u.isAstStmtValid() {
		u.tableAliasesMap = u.parseTableName(u.parserCtx.UpdateStmt.TableRefs.TableRefs)
	}
	beforeImages, err := u.beforeImage(ctx)
	if err != nil {
		return nil, err
	}
	// Execute the actual UPDATE statement.
	res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues)
	if err != nil {
		return nil, err
	}
	afterImages, err := u.afterImage(ctx, beforeImages)
	if err != nil {
		return nil, err
	}
	// A count mismatch means the rows captured before the update no longer
	// correspond one-to-one afterwards — most likely a primary key changed.
	if len(afterImages) != len(beforeImages) {
		return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.")
	}
	u.execContext.TxCtx.RoundImages.AppendBeofreImages(beforeImages)
	u.execContext.TxCtx.RoundImages.AppendAfterImages(afterImages)
	return res, nil
}
func (u *updateJoinExecutor) isAstStmtValid() bool { | |||||
return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil && u.parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil | |||||
} | |||||
func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) { | |||||
if !u.isAstStmtValid() { | |||||
return nil, nil | |||||
} | |||||
var recordImages []*types.RecordImage | |||||
for tbName, tableAliases := range u.tableAliasesMap { | |||||
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tbName) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, tableAliases, u.execContext.NamedValues) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if selectSQL == "" { | |||||
log.Debugf("Skip unused table [{%s}] when build select sql by update sourceQuery", tbName) | |||||
continue | |||||
} | |||||
var image *types.RecordImage | |||||
rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) | |||||
if err == nil { | |||||
image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) | |||||
} | |||||
if rowsi != nil { | |||||
if rowerr := rows.Close(); rowerr != nil { | |||||
log.Errorf("rows close fail, err:%v", rowerr) | |||||
return nil, rowerr | |||||
} | |||||
} | |||||
if err != nil { | |||||
// If one fail, all fails | |||||
return nil, err | |||||
} | |||||
lockKey := u.buildLockKey(image, *metaData) | |||||
u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} | |||||
image.SQLType = u.parserCtx.SQLType | |||||
recordImages = append(recordImages, image) | |||||
} | |||||
return recordImages, nil | |||||
} | |||||
// afterImage re-queries, per before image, the post-update state of the
// affected rows and returns one record image per before image, in the same
// order.
func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
	if !u.isAstStmtValid() {
		return nil, nil
	}
	if len(beforeImages) == 0 {
		return nil, errors.New("empty beforeImages")
	}
	var recordImages []*types.RecordImage
	for _, beforeImage := range beforeImages {
		metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName)
		if err != nil {
			return nil, err
		}
		selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData, u.tableAliasesMap[beforeImage.TableName])
		if err != nil {
			return nil, err
		}
		// NOTE(review): buildAfterImageSQL returns "" (nil error) when the
		// before image has no rows; that empty SQL is passed straight to
		// rowsPrepare here — confirm the intended behavior for that case.
		var image *types.RecordImage
		rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
		if err == nil {
			image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
		}
		if rowsi != nil {
			if rowerr := rowsi.Close(); rowerr != nil {
				log.Errorf("rows close fail, err:%v", rowerr)
				return nil, rowerr
			}
		}
		if err != nil {
			// If one fail, all fails
			return nil, err
		}
		image.SQLType = u.parserCtx.SQLType
		recordImages = append(recordImages, image)
	}
	return recordImages, nil
}
// buildBeforeImageSQL build the SQL to query before image data for one
// table of the update join. It returns an empty SQL string (nil error) when
// none of the updated columns belong to the table, signalling the caller to
// skip it. (The header comment previously misnamed this function
// "buildAfterImageSQL".)
func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) {
	updateStmt := u.parserCtx.UpdateStmt
	fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List)
	if err != nil {
		return "", nil, err
	}
	// err is nil here; empty fields = this table is unused by the update.
	if len(fields) == 0 {
		return "", nil, err
	}
	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmt.TableRefs,
		Where:          updateStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		OrderBy:        updateStmt.Order,
		Limit:          updateStmt.Limit,
		TableHints:     updateStmt.TableHints,
		// maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
		GroupBy: &ast.GroupByClause{
			Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableAliases, tableMeta.GetPrimaryKeyOnlyName(), fields),
		},
		// FOR UPDATE: lock the rows so the before image stays consistent
		// until the update runs.
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}
	b := bytes.NewByteBuffer([]byte{})
	_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
	return sql, u.buildSelectArgs(&selStmt, args), nil
}
// buildAfterImageSQL builds the SQL to re-query the rows captured in the
// before image after the update has run. Returns "" (nil error) when the
// before image is empty.
//
// NOTE(review): the statement reuses the original update's WHERE clause but
// the returned args are the before-image primary-key values from
// buildPKParams — confirm the placeholder count/order in the restored SQL
// actually matches these args.
func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta, tableAliases string) (string, []driver.NamedValue, error) {
	if len(beforeImage.Rows) == 0 {
		return "", nil, nil
	}
	fields, err := u.buildSelectFields(ctx, meta, tableAliases, u.parserCtx.UpdateStmt.List)
	if err != nil {
		return "", nil, err
	}
	// err is nil here; empty fields = this table is unused by the update.
	if len(fields) == 0 {
		return "", nil, err
	}
	updateStmt := u.parserCtx.UpdateStmt
	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmt.TableRefs,
		Where:          updateStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		OrderBy:        updateStmt.Order,
		Limit:          updateStmt.Limit,
		TableHints:     updateStmt.TableHints,
		// maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
		GroupBy: &ast.GroupByClause{
			Items: u.buildGroupByClause(ctx, meta.TableName, tableAliases, meta.GetPrimaryKeyOnlyName(), fields),
		},
	}
	b := bytes.NewByteBuffer([]byte{})
	_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
	return sql, u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()), nil
}
func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) map[string]string { | |||||
tableNames := make(map[string]string, 0) | |||||
if item, ok := joinMate.Left.(*ast.Join); ok { | |||||
tableNames = u.parseTableName(item) | |||||
} else { | |||||
leftTableSource := joinMate.Left.(*ast.TableSource) | |||||
leftName := leftTableSource.Source.(*ast.TableName) | |||||
tableNames[leftName.Name.O] = leftTableSource.AsName.O | |||||
} | |||||
rightTableSource := joinMate.Right.(*ast.TableSource) | |||||
rightName := rightTableSource.Source.(*ast.TableName) | |||||
tableNames[rightName.Name.O] = rightTableSource.AsName.O | |||||
return tableNames | |||||
} | |||||
// build group by condition which used for removing duplicate row in select join sql | |||||
func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, tableAliases string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem { | |||||
var groupByPks = true | |||||
if tableAliases != "" { | |||||
tableName = tableAliases | |||||
} | |||||
//only pks group by is valid when db version >= 5.7.5 | |||||
if u.isLowerSupportGroupByPksVersion { | |||||
if u.sqlMode == "" { | |||||
rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, "SELECT @@SQL_MODE", nil) | |||||
defer func() { | |||||
if rowsi != nil { | |||||
if rowerr := rowsi.Close(); rowerr != nil { | |||||
log.Errorf("rows close fail, err:%v", rowerr) | |||||
} | |||||
} | |||||
}() | |||||
if err != nil { | |||||
groupByPks = false | |||||
log.Warnf("determine group by pks or all columns error:%s", err) | |||||
} else { | |||||
// getString("@@SQL_MODE") | |||||
mode := make([]driver.Value, 1) | |||||
if err = rowsi.Next(mode); err != nil { | |||||
if err != io.EOF && len(mode) == 1 { | |||||
u.sqlMode = reflect.ValueOf(mode[0]).String() | |||||
} | |||||
} | |||||
} | |||||
} | |||||
if strings.Contains(u.sqlMode, "ONLY_FULL_GROUP_BY") { | |||||
groupByPks = false | |||||
} | |||||
} | |||||
groupByColumns := make([]*ast.ByItem, 0) | |||||
if groupByPks { | |||||
for _, column := range pkColumns { | |||||
groupByColumns = append(groupByColumns, &ast.ByItem{ | |||||
Expr: &ast.ColumnNameExpr{ | |||||
Name: &ast.ColumnName{ | |||||
Table: model.CIStr{ | |||||
O: tableName, | |||||
L: strings.ToLower(tableName), | |||||
}, | |||||
Name: model.CIStr{ | |||||
O: column, | |||||
L: strings.ToLower(column), | |||||
}, | |||||
}, | |||||
}, | |||||
}) | |||||
} | |||||
} else { | |||||
for _, column := range allSelectColumns { | |||||
groupByColumns = append(groupByColumns, &ast.ByItem{ | |||||
Expr: column.Expr, | |||||
}) | |||||
} | |||||
} | |||||
return groupByColumns | |||||
} |
@@ -0,0 +1,231 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
package at | |||||
import ( | |||||
"context" | |||||
"database/sql/driver" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/undo" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/exec" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/parser" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | |||||
_ "seata.apache.org/seata-go/pkg/util/log" | |||||
) | |||||
func TestBuildSelectSQLByUpdateJoin(t *testing.T) { | |||||
MetaDataMap := map[string]*types.TableMeta{ | |||||
"table1": { | |||||
TableName: "table1", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"id": { | |||||
IType: types.IndexTypePrimaryKey, | |||||
Columns: []types.ColumnMeta{ | |||||
{ColumnName: "id"}, | |||||
}, | |||||
}, | |||||
}, | |||||
Columns: map[string]types.ColumnMeta{ | |||||
"id": { | |||||
ColumnDef: nil, | |||||
ColumnName: "id", | |||||
}, | |||||
"name": { | |||||
ColumnDef: nil, | |||||
ColumnName: "name", | |||||
}, | |||||
"age": { | |||||
ColumnDef: nil, | |||||
ColumnName: "age", | |||||
}, | |||||
}, | |||||
ColumnNames: []string{"id", "name", "age"}, | |||||
}, | |||||
"table2": { | |||||
TableName: "table2", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"id": { | |||||
IType: types.IndexTypePrimaryKey, | |||||
Columns: []types.ColumnMeta{ | |||||
{ColumnName: "id"}, | |||||
}, | |||||
}, | |||||
}, | |||||
Columns: map[string]types.ColumnMeta{ | |||||
"id": { | |||||
ColumnDef: nil, | |||||
ColumnName: "id", | |||||
}, | |||||
"name": { | |||||
ColumnDef: nil, | |||||
ColumnName: "name", | |||||
}, | |||||
"age": { | |||||
ColumnDef: nil, | |||||
ColumnName: "age", | |||||
}, | |||||
"kk": { | |||||
ColumnDef: nil, | |||||
ColumnName: "kk", | |||||
}, | |||||
"addr": { | |||||
ColumnDef: nil, | |||||
ColumnName: "addr", | |||||
}, | |||||
}, | |||||
ColumnNames: []string{"id", "name", "age", "kk", "addr"}, | |||||
}, | |||||
"table3": { | |||||
TableName: "table3", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"id": { | |||||
IType: types.IndexTypePrimaryKey, | |||||
Columns: []types.ColumnMeta{ | |||||
{ColumnName: "id"}, | |||||
}, | |||||
}, | |||||
}, | |||||
Columns: map[string]types.ColumnMeta{ | |||||
"id": { | |||||
ColumnDef: nil, | |||||
ColumnName: "id", | |||||
}, | |||||
"age": { | |||||
ColumnDef: nil, | |||||
ColumnName: "age", | |||||
}, | |||||
}, | |||||
ColumnNames: []string{"id", "age"}, | |||||
}, | |||||
"table4": { | |||||
TableName: "table4", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"id": { | |||||
IType: types.IndexTypePrimaryKey, | |||||
Columns: []types.ColumnMeta{ | |||||
{ColumnName: "id"}, | |||||
}, | |||||
}, | |||||
}, | |||||
Columns: map[string]types.ColumnMeta{ | |||||
"id": { | |||||
ColumnDef: nil, | |||||
ColumnName: "id", | |||||
}, | |||||
"age": { | |||||
ColumnDef: nil, | |||||
ColumnName: "age", | |||||
}, | |||||
}, | |||||
ColumnNames: []string{"id", "age"}, | |||||
}, | |||||
} | |||||
undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) | |||||
tests := []struct { | |||||
name string | |||||
sourceQuery string | |||||
sourceQueryArgs []driver.Value | |||||
expectQuery map[string]string | |||||
expectQueryArgs []driver.Value | |||||
}{ | |||||
{ | |||||
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id and t1.age=? set t1.name = 'WILL',t2.name = ?", | |||||
sourceQueryArgs: []driver.Value{18, "Jack"}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t1.name,t1.id FOR UPDATE", | |||||
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t2.name,t2.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{18}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 AS t1 inner join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", | |||||
sourceQueryArgs: []driver.Value{1}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", | |||||
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{1}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 AS t1 right join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", | |||||
sourceQueryArgs: []driver.Value{1}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", | |||||
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{1}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", | |||||
sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", | |||||
sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{1, 18, 28}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", | |||||
sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", | |||||
"table2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, | |||||
}, | |||||
{ | |||||
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id inner join table3 t3 on t3.id = t2.id right join table4 t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", | |||||
sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, | |||||
expectQuery: map[string]string{ | |||||
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", | |||||
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", | |||||
}, | |||||
expectQueryArgs: []driver.Value{1, 10}, | |||||
}, | |||||
} | |||||
for _, tt := range tests { | |||||
t.Run(tt.name, func(t *testing.T) { | |||||
c, err := parser.DoParser(tt.sourceQuery) | |||||
assert.Nil(t, err) | |||||
executor := NewUpdateJoinExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) | |||||
tableNames := executor.(*updateJoinExecutor).parseTableName(c.UpdateStmt.TableRefs.TableRefs) | |||||
for tbName, tableAliases := range tableNames { | |||||
query, args, err := executor.(*updateJoinExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap[tbName], tableAliases, util.ValueToNamedValue(tt.sourceQueryArgs)) | |||||
assert.Nil(t, err) | |||||
if query == "" { | |||||
continue | |||||
} | |||||
assert.Equal(t, tt.expectQuery[tbName], query) | |||||
assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) | |||||
} | |||||
}) | |||||
} | |||||
} |
@@ -22,12 +22,12 @@ import ( | |||||
"database/sql" | "database/sql" | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"fmt" | "fmt" | ||||
"io" | |||||
"strings" | |||||
"github.com/arana-db/parser/ast" | "github.com/arana-db/parser/ast" | ||||
"github.com/arana-db/parser/test_driver" | "github.com/arana-db/parser/test_driver" | ||||
gxsort "github.com/dubbogo/gost/sort" | gxsort "github.com/dubbogo/gost/sort" | ||||
"io" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/util" | |||||
"strings" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | "seata.apache.org/seata-go/pkg/datasource/sql/types" | ||||
) | ) | ||||
@@ -276,43 +276,5 @@ func (b *BasicUndoLogBuilder) buildLockKey(rows driver.Rows, meta types.TableMet | |||||
// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" | // the string as local key. the local key example(multi pk): "t_user:1_a,2_b" | ||||
func (b *BasicUndoLogBuilder) buildLockKey2(records *types.RecordImage, meta types.TableMeta) string { | func (b *BasicUndoLogBuilder) buildLockKey2(records *types.RecordImage, meta types.TableMeta) string { | ||||
var lockKeys bytes.Buffer | |||||
lockKeys.WriteString(meta.TableName) | |||||
lockKeys.WriteString(":") | |||||
keys := meta.GetPrimaryKeyOnlyName() | |||||
keyIndexMap := make(map[string]int, len(keys)) | |||||
for idx, columnName := range keys { | |||||
keyIndexMap[columnName] = idx | |||||
} | |||||
primaryKeyRows := make([][]interface{}, len(records.Rows)) | |||||
for i, row := range records.Rows { | |||||
primaryKeyValues := make([]interface{}, len(keys)) | |||||
for _, column := range row.Columns { | |||||
if idx, exist := keyIndexMap[column.ColumnName]; exist { | |||||
primaryKeyValues[idx] = column.Value | |||||
} | |||||
} | |||||
primaryKeyRows[i] = primaryKeyValues | |||||
} | |||||
for i, primaryKeyValues := range primaryKeyRows { | |||||
if i > 0 { | |||||
lockKeys.WriteString(",") | |||||
} | |||||
for j, pkVal := range primaryKeyValues { | |||||
if j > 0 { | |||||
lockKeys.WriteString("_") | |||||
} | |||||
if pkVal == nil { | |||||
continue | |||||
} | |||||
lockKeys.WriteString(fmt.Sprintf("%v", pkVal)) | |||||
} | |||||
} | |||||
return lockKeys.String() | |||||
return util.BuildLockKey(records, meta) | |||||
} | } |
@@ -69,6 +69,7 @@ func TestBuildLockKey(t *testing.T) { | |||||
} | } | ||||
columnsTwoPk := []types.ColumnMeta{columnID, columnUserId} | columnsTwoPk := []types.ColumnMeta{columnID, columnUserId} | ||||
columnsThreePk := []types.ColumnMeta{columnID, columnUserId, columnAge} | |||||
columnsMixPk := []types.ColumnMeta{columnName, columnAge} | columnsMixPk := []types.ColumnMeta{columnName, columnAge} | ||||
getColumnImage := func(columnName string, value interface{}) types.ColumnImage { | getColumnImage := func(columnName string, value interface{}) types.ColumnImage { | ||||
@@ -98,6 +99,24 @@ func TestBuildLockKey(t *testing.T) { | |||||
}, | }, | ||||
"test_name:1_one,2_two", | "test_name:1_one,2_two", | ||||
}, | }, | ||||
{ | |||||
"Three Primary Keys", | |||||
types.TableMeta{ | |||||
TableName: "test2_name", | |||||
Indexs: map[string]types.IndexMeta{ | |||||
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsThreePk}, | |||||
}, | |||||
}, | |||||
types.RecordImage{ | |||||
TableName: "test2_name", | |||||
Rows: []types.RowImage{ | |||||
{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "one"), getColumnImage("age", "11")}}, | |||||
{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "two"), getColumnImage("age", "22")}}, | |||||
{[]types.ColumnImage{getColumnImage("id", 3), getColumnImage("userId", "three"), getColumnImage("age", "33")}}, | |||||
}, | |||||
}, | |||||
"test2_name:1_one_11,2_two_22,3_three_33", | |||||
}, | |||||
{ | { | ||||
name: "Single Primary Key", | name: "Single Primary Key", | ||||
metaData: types.TableMeta{ | metaData: types.TableMeta{ | ||||
@@ -0,0 +1,75 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
// Copyright 2016 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
package util | |||||
import ( | |||||
"fmt" | |||||
"seata.apache.org/seata-go/pkg/datasource/sql/types" | |||||
"strings" | |||||
) | |||||
func BuildLockKey(records *types.RecordImage, meta types.TableMeta) string { | |||||
var lockKeys strings.Builder | |||||
type ColMapItem struct { | |||||
pkIndex int | |||||
colIndex int | |||||
} | |||||
lockKeys.WriteString(meta.TableName) | |||||
lockKeys.WriteString(":") | |||||
keys := meta.GetPrimaryKeyOnlyName() | |||||
keyIndexMap := make(map[string]int, len(keys)) | |||||
for idx, columnName := range keys { | |||||
keyIndexMap[columnName] = idx | |||||
} | |||||
columns := make([]ColMapItem, 0, len(keys)) | |||||
if len(records.Rows) > 0 { | |||||
for colIdx, column := range records.Rows[0].Columns { | |||||
if pkIdx, ok := keyIndexMap[column.ColumnName]; ok { | |||||
columns = append(columns, ColMapItem{pkIndex: pkIdx, colIndex: colIdx}) | |||||
} | |||||
} | |||||
for i, row := range records.Rows { | |||||
if i > 0 { | |||||
lockKeys.WriteString(",") | |||||
} | |||||
primaryKeyValues := make([]interface{}, len(keys)) | |||||
for _, mp := range columns { | |||||
if mp.colIndex < len(row.Columns) { | |||||
primaryKeyValues[mp.pkIndex] = row.Columns[mp.colIndex].Value | |||||
} | |||||
} | |||||
for j, pkVal := range primaryKeyValues { | |||||
if j > 0 { | |||||
lockKeys.WriteString("_") | |||||
} | |||||
if pkVal == nil { | |||||
continue | |||||
} | |||||
lockKeys.WriteString(fmt.Sprintf("%v", pkVal)) | |||||
} | |||||
} | |||||
} | |||||
return lockKeys.String() | |||||
} |