Compare commits

...

3 Commits

Author SHA1 Message Date
  Aster Zephyr 08cdcf2166
the ability to automatically run unit tests after creating a pull request. (#764) 3 months ago
  panlei-coder 15d5047827
fix: Solve the conflict problem of introducing multiple versions of knadh (#772) 3 months ago
  Aster Zephyr 8e56d22d7d
bugfix: error image when use null value as image query condition in insert on duplicate #704 (#725) 3 months ago
16 changed files with 1295 additions and 154 deletions
Unified View
  1. +1
    -7
      .github/workflows/build.yml
  2. +2
    -0
      .github/workflows/golangci-lint.yml
  3. +78
    -0
      .github/workflows/unit-test.yml
  4. +0
    -1
      go.mod
  5. +1
    -1
      pkg/client/config.go
  6. +94
    -31
      pkg/datasource/sql/exec/at/base_executor.go
  7. +212
    -0
      pkg/datasource/sql/exec/at/base_executor_test.go
  8. +6
    -2
      pkg/datasource/sql/exec/at/update_executor.go
  9. +348
    -0
      pkg/datasource/sql/exec/at/update_join_executor.go
  10. +231
    -0
      pkg/datasource/sql/exec/at/update_join_executor_test.go
  11. +7
    -4
      pkg/datasource/sql/types/image.go
  12. +4
    -42
      pkg/datasource/sql/undo/builder/basic_undo_log_builder.go
  13. +19
    -0
      pkg/datasource/sql/undo/builder/basic_undo_log_builder_test.go
  14. +154
    -66
      pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go
  15. +63
    -0
      pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go
  16. +75
    -0
      pkg/datasource/sql/util/lockkey.go

+ 1
- 7
.github/workflows/build.yml View File

@@ -17,7 +17,7 @@


# This is a workflow to help you test the unit case and show codecov # This is a workflow to help you test the unit case and show codecov


name: "build and codecov"
name: "build"


on: on:
push: push:
@@ -60,9 +60,3 @@ jobs:


- name: "run go build" - name: "run go build"
run: go build -v ./... run: go build -v ./...

- name: "run go test and out codecov"
run: go test -v ./... -race -coverprofile=coverage.out -covermode=atomic

- name: "upload coverage"
uses: codecov/codecov-action@v3

+ 2
- 0
.github/workflows/golangci-lint.yml View File

@@ -58,3 +58,5 @@ jobs:
version: v1.51.0 version: v1.51.0
args: --timeout=10m args: --timeout=10m
skip-go-installation: true skip-go-installation: true
skip-cache: true
skip-pkg-cache: true

+ 78
- 0
.github/workflows/unit-test.yml View File

@@ -0,0 +1,78 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

name: "Unit Test"

on:
push:
branches: [ master ]
pull_request:
branches: [ "*" ]
types: [opened, synchronize, reopened]

permissions:
contents: read

jobs:
unit-test:
name: Unit Test
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
matrix:
golang:
- 1.18

steps:
- name: "Set up Go"
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.golang }}

- name: "Checkout code"
uses: actions/checkout@v3
with:
submodules: true

- name: "Cache dependencies"
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: "${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}"
restore-keys: |
"${{ runner.os }}-go-"

- name: Shutdown default mysql
run: sudo service mysql stop


- name: "Run Unit Tests"
run: |
echo "=== Starting Unit Tests ==="
go test -v ./... -race -coverprofile=coverage.txt -covermode=atomic -timeout 10m
if [ $? -ne 0 ]; then
echo "❌ Unit tests failed"
exit 1
fi
echo "✅ Unit tests completed successfully"

- name: "Archive test results"
uses: actions/upload-artifact@v3
with:
name: test-results
path: coverage.txt
retention-days: 7

+ 0
- 1
go.mod View File

@@ -39,7 +39,6 @@ require (


require ( require (
github.com/knadh/koanf v1.5.0 github.com/knadh/koanf v1.5.0
github.com/knadh/koanf/v2 v2.1.2
) )


require ( require (


+ 1
- 1
pkg/client/config.go View File

@@ -26,11 +26,11 @@ import (
"runtime" "runtime"
"strings" "strings"


"github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/json" "github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/rawbytes" "github.com/knadh/koanf/providers/rawbytes"
koanf "github.com/knadh/koanf/v2"


"seata.apache.org/seata-go/pkg/discovery" "seata.apache.org/seata-go/pkg/discovery"




+ 94
- 31
pkg/datasource/sql/exec/at/base_executor.go View File

@@ -18,20 +18,21 @@
package at package at


import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
"strings" "strings"


"github.com/arana-db/parser/ast" "github.com/arana-db/parser/ast"
"github.com/arana-db/parser/model"
"github.com/arana-db/parser/test_driver" "github.com/arana-db/parser/test_driver"
gxsort "github.com/dubbogo/gost/sort" gxsort "github.com/dubbogo/gost/sort"
"github.com/pkg/errors"


"seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/exec"
"seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
"seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/datasource/sql/util"
"seata.apache.org/seata-go/pkg/util/reflectx" "seata.apache.org/seata-go/pkg/util/reflectx"
) )
@@ -98,7 +99,13 @@ func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.Named
selectArgs = make([]driver.NamedValue, 0) selectArgs = make([]driver.NamedValue, 0)
) )


b.traversalArgs(stmt.From.TableRefs, &selectArgsIndexs)
b.traversalArgs(stmt.Where, &selectArgsIndexs) b.traversalArgs(stmt.Where, &selectArgsIndexs)
if stmt.GroupBy != nil {
for _, item := range stmt.GroupBy.Items {
b.traversalArgs(item, &selectArgsIndexs)
}
}
if stmt.OrderBy != nil { if stmt.OrderBy != nil {
for _, item := range stmt.OrderBy.Items { for _, item := range stmt.OrderBy.Items {
b.traversalArgs(item, &selectArgsIndexs) b.traversalArgs(item, &selectArgsIndexs)
@@ -143,6 +150,16 @@ func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) {
b.traversalArgs(exprs[i], argsIndex) b.traversalArgs(exprs[i], argsIndex)
} }
break break
case *ast.Join:
exprs := node.(*ast.Join)
b.traversalArgs(exprs.Left, argsIndex)
if exprs.Right != nil {
b.traversalArgs(exprs.Right, argsIndex)
}
if exprs.On != nil {
b.traversalArgs(exprs.On.Expr, argsIndex)
}
break
case *test_driver.ParamMarkerExpr: case *test_driver.ParamMarkerExpr:
*argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order))
break break
@@ -230,6 +247,64 @@ func (b *baseExecutor) containsPKByName(meta *types.TableMeta, columns []string)
return matchCounter == len(pkColumnNameList) return matchCounter == len(pkColumnNameList)
} }


// buildSelectFields builds the SELECT field list used when querying the
// before/after image of an update statement.
//
// When undo.UndoConfig.OnlyCareUpdateColumns is enabled, only the updated
// columns that target this table (matched by alias, or table name when no
// alias is used) are selected, plus the table's primary-key columns so the
// rows can be identified. If no updated column belongs to this table, an
// empty slice is returned so the caller can skip the table entirely.
// Otherwise a single "*" field is returned to select every column.
//
// Receiver renamed u -> b for consistency with the other baseExecutor methods.
func (b *baseExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, inUseFields []*ast.Assignment) ([]*ast.SelectField, error) {
	fields := make([]*ast.SelectField, 0, len(inUseFields))

	// Prefer the alias when the table is referenced through one.
	tableName := tableAliases
	if tableAliases == "" {
		tableName = tableMeta.TableName
	}
	if undo.UndoConfig.OnlyCareUpdateColumns {
		for _, column := range inUseFields {
			tn := column.Column.Table.O
			// Skip assignments that target a different table of the join.
			if tn != "" && tn != tableName {
				continue
			}

			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: column.Column,
				},
			})
		}

		// No updated column belongs to this table: nothing to select.
		if len(fields) == 0 {
			return fields, nil
		}

		// Always select the primary-key columns so image rows can be matched.
		// NOTE(review): CIStr.L is conventionally the lowercased form, but the
		// original code stores the original case here — preserved as-is.
		for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() {
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: &ast.ColumnName{
						Table: model.CIStr{
							O: tableName,
							L: tableName,
						},
						Name: model.CIStr{
							O: columnName,
							L: columnName,
						},
					},
				},
			})
		}
	} else {
		// Not restricted to updated columns: select everything.
		fields = append(fields, &ast.SelectField{
			Expr: &ast.ColumnNameExpr{
				Name: &ast.ColumnName{
					Name: model.CIStr{
						O: "*",
						L: "*",
					},
				},
			},
		})
	}

	return fields, nil
}

func getSqlNullValue(value interface{}) interface{} { func getSqlNullValue(value interface{}) interface{} {
if value == nil { if value == nil {
return nil return nil
@@ -359,37 +434,25 @@ func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string)


// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" // the string as local key. the local key example(multi pk): "t_user:1_a,2_b"
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string { func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string {
var (
lockKeys bytes.Buffer
filedSequence int
)
lockKeys.WriteString(meta.TableName)
lockKeys.WriteString(":")
return util.BuildLockKey(records, meta)
}


keys := meta.GetPrimaryKeyOnlyName()
func (b *baseExecutor) rowsPrepare(ctx context.Context, conn driver.Conn, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) {
var queryer driver.Queryer


for _, row := range records.Rows {
if filedSequence > 0 {
lockKeys.WriteString(",")
}
pkSplitIndex := 0
for _, column := range row.Columns {
var hasKeyColumn bool
for _, key := range keys {
if column.ColumnName == key {
hasKeyColumn = true
if pkSplitIndex > 0 {
lockKeys.WriteString("_")
}
lockKeys.WriteString(fmt.Sprintf("%v", column.Value))
pkSplitIndex++
}
}
if hasKeyColumn {
filedSequence++
}
}
queryerContext, ok := conn.(driver.QueryerContext)
if !ok {
queryer, ok = conn.(driver.Queryer)
} }
if ok {
var err error
rows, err = util.CtxDriverQuery(ctx, queryerContext, queryer, selectSQL, selectArgs)


return lockKeys.String()
if err != nil {
return nil, err
}
} else {
return nil, errors.New("target conn should been driver.QueryerContext or driver.Queryer")
}
return rows, nil
} }

+ 212
- 0
pkg/datasource/sql/exec/at/base_executor_test.go View File

@@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"github.com/stretchr/testify/assert"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"testing"
)

// TestBaseExecBuildLockKey verifies baseExecutor.buildLockKey's lock-key
// format "table:pk1_pk2,pk1_pk2,...": primary-key values of one row joined
// with '_', rows joined with ','. Covers multi-PK tables, nil/bool/float PK
// values, special characters, empty record sets, and PK names absent from
// the row image.
func TestBaseExecBuildLockKey(t *testing.T) {
	// NOTE(review): this local variable shadows the imported "exec" package
	// inside the test; consider renaming.
	var exec baseExecutor

	columnID := types.ColumnMeta{
		ColumnName: "id",
	}
	columnUserId := types.ColumnMeta{
		ColumnName: "userId",
	}
	columnName := types.ColumnMeta{
		ColumnName: "name",
	}
	columnAge := types.ColumnMeta{
		ColumnName: "age",
	}
	columnNonExistent := types.ColumnMeta{
		ColumnName: "non_existent",
	}

	columnsTwoPk := []types.ColumnMeta{columnID, columnUserId}
	columnsThreePk := []types.ColumnMeta{columnID, columnUserId, columnAge}
	columnsMixPk := []types.ColumnMeta{columnName, columnAge}

	// getColumnImage builds a primary-key column image for one cell.
	getColumnImage := func(columnName string, value interface{}) types.ColumnImage {
		return types.ColumnImage{KeyType: types.IndexTypePrimaryKey, ColumnName: columnName, Value: value}
	}

	tests := []struct {
		name     string
		metaData types.TableMeta
		records  types.RecordImage
		expected string
	}{
		{
			"Two Primary Keys",
			types.TableMeta{
				TableName: "test_name",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsTwoPk},
				},
			},
			types.RecordImage{
				TableName: "test_name",
				Rows: []types.RowImage{
					{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "user1")}},
					{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "user2")}},
				},
			},
			"test_name:1_user1,2_user2",
		},
		{
			"Three Primary Keys",
			types.TableMeta{
				TableName: "test2_name",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsThreePk},
				},
			},
			types.RecordImage{
				TableName: "test2_name",
				Rows: []types.RowImage{
					{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "one"), getColumnImage("age", "11")}},
					{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "two"), getColumnImage("age", "22")}},
					{[]types.ColumnImage{getColumnImage("id", 3), getColumnImage("userId", "three"), getColumnImage("age", "33")}},
				},
			},
			"test2_name:1_one_11,2_two_22,3_three_33",
		},
		{
			name: "Single Primary Key",
			metaData: types.TableMeta{
				TableName: "single_key",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}},
				},
			},
			records: types.RecordImage{
				TableName: "single_key",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("id", 100)}},
				},
			},
			expected: "single_key:100",
		},
		{
			name: "Mixed Type Keys",
			metaData: types.TableMeta{
				TableName: "mixed_key",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsMixPk},
				},
			},
			records: types.RecordImage{
				TableName: "mixed_key",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("name", "mike"), getColumnImage("age", 25)}},
				},
			},
			expected: "mixed_key:mike_25",
		},
		{
			// No rows at all: the key is just "table:".
			name: "Empty Records",
			metaData: types.TableMeta{
				TableName: "empty",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}},
				},
			},
			records:  types.RecordImage{TableName: "empty"},
			expected: "empty:",
		},
		{
			// Separator characters inside PK values are not escaped.
			name: "Special Characters",
			metaData: types.TableMeta{
				TableName: "special",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}},
				},
			},
			records: types.RecordImage{
				TableName: "special",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("id", "A,b_c")}},
				},
			},
			expected: "special:A,b_c",
		},
		{
			// PK declared in meta does not appear in the row image: value is
			// simply omitted after the colon.
			name: "Non-existent Key Name",
			metaData: types.TableMeta{
				TableName: "error_key",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnNonExistent}},
				},
			},
			records: types.RecordImage{
				TableName: "error_key",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("id", 1)}},
				},
			},
			expected: "error_key:",
		},
		{
			// nil PK values render as empty strings between commas.
			name: "Multiple Rows With Nil PK Value",
			metaData: types.TableMeta{
				TableName: "nil_pk",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnID}},
				},
			},
			records: types.RecordImage{
				TableName: "nil_pk",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("id", nil)}},
					{Columns: []types.ColumnImage{getColumnImage("id", 123)}},
					{Columns: []types.ColumnImage{getColumnImage("id", nil)}},
				},
			},
			expected: "nil_pk:,123,",
		},
		{
			// Bool and float values use their default %v formatting.
			name: "PK As Bool And Float",
			metaData: types.TableMeta{
				TableName: "type_pk",
				Indexs: map[string]types.IndexMeta{
					"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: []types.ColumnMeta{columnName, columnAge}},
				},
			},
			records: types.RecordImage{
				TableName: "type_pk",
				Rows: []types.RowImage{
					{Columns: []types.ColumnImage{getColumnImage("name", true), getColumnImage("age", 3.14)}},
					{Columns: []types.ColumnImage{getColumnImage("name", false), getColumnImage("age", 0.0)}},
				},
			},
			expected: "type_pk:true_3.14,false_0",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			lockKeys := exec.buildLockKey(&tt.records, tt.metaData)
			assert.Equal(t, tt.expected, lockKeys)
		})
	}
}

+ 6
- 2
pkg/datasource/sql/exec/at/update_executor.go View File

@@ -21,17 +21,17 @@ import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/arana-db/parser/model"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"strings" "strings"


"github.com/arana-db/parser/ast" "github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format" "github.com/arana-db/parser/format"
"github.com/arana-db/parser/model"


"seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/datasource"
"seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/exec"
"seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/undo"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"seata.apache.org/seata-go/pkg/util/bytes" "seata.apache.org/seata-go/pkg/util/bytes"
"seata.apache.org/seata-go/pkg/util/log" "seata.apache.org/seata-go/pkg/util/log"
) )
@@ -49,6 +49,10 @@ type updateExecutor struct {


// NewUpdateExecutor get update executor // NewUpdateExecutor get update executor
func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
// Because update join cannot be clearly identified when SQL cannot be parsed
if parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil {
return NewUpdateJoinExecutor(parserCtx, execContent, hooks)
}
return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}} return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}
} }




+ 348
- 0
pkg/datasource/sql/exec/at/update_join_executor.go View File

@@ -0,0 +1,348 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"context"
"database/sql/driver"
"errors"
"io"
"reflect"
"strings"

"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format"
"github.com/arana-db/parser/model"

"seata.apache.org/seata-go/pkg/datasource/sql/datasource"
"seata.apache.org/seata-go/pkg/datasource/sql/exec"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"seata.apache.org/seata-go/pkg/util/bytes"
"seata.apache.org/seata-go/pkg/util/log"
)

const (
	// LowerSupportGroupByPksVersion is the first MySQL version that supports
	// grouping by the primary keys alone; on older servers the GROUP BY
	// handling depends on SQL_MODE (see buildGroupByClause).
	LowerSupportGroupByPksVersion = "5.7.5"
)

// updateJoinExecutor execute update SQL
// It handles UPDATE ... JOIN statements, producing one before/after image
// per table touched by the update.
type updateJoinExecutor struct {
	baseExecutor
	parserCtx   *types.ParseContext
	execContext *types.ExecContext
	// isLowerSupportGroupByPksVersion is true when the connected DB version
	// is below LowerSupportGroupByPksVersion.
	isLowerSupportGroupByPksVersion bool
	// sqlMode caches the server's @@SQL_MODE once fetched.
	sqlMode string
	// tableAliasesMap maps table name -> alias used in the join clause
	// (empty string when the table has no alias).
	tableAliasesMap map[string]string
}

// NewUpdateJoinExecutor builds the executor that handles UPDATE ... JOIN
// statements, pre-computing whether the connected DB version is below the
// minimum that supports group-by-primary-keys.
func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
	lowerBound, _ := util.ConvertDbVersion(LowerSupportGroupByPksVersion)
	current, _ := util.ConvertDbVersion(execContent.DbVersion)
	e := &updateJoinExecutor{
		parserCtx:       parserCtx,
		execContext:     execContent,
		baseExecutor:    baseExecutor{hooks: hooks},
		tableAliasesMap: map[string]string{},
	}
	e.isLowerSupportGroupByPksVersion = current < lowerBound
	return e
}

// ExecContext executes the update-join SQL via the callback f and generates
// the before and after images for every joined table, recording them (and
// the corresponding lock keys, via beforeImage) on the transaction context.
// Returns an error when the two image sets differ in size, which indicates
// the update modified primary-key columns.
func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
	u.beforeHooks(ctx, u.execContext)
	defer func() {
		u.afterHooks(ctx, u.execContext)
	}()

	// Resolve table -> alias pairs up front; image SQL is built per table.
	if u.isAstStmtValid() {
		u.tableAliasesMap = u.parseTableName(u.parserCtx.UpdateStmt.TableRefs.TableRefs)
	}

	beforeImages, err := u.beforeImage(ctx)
	if err != nil {
		return nil, err
	}

	res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues)
	if err != nil {
		return nil, err
	}

	afterImages, err := u.afterImage(ctx, beforeImages)
	if err != nil {
		return nil, err
	}

	// A size mismatch means rows appeared/disappeared between the snapshots.
	// Error string reworded to Go convention (lowercase, no period — ST1005).
	if len(afterImages) != len(beforeImages) {
		return nil, errors.New("before image size is not equal to after image size, probably because you updated the primary keys")
	}

	u.execContext.TxCtx.RoundImages.AppendBeofreImages(beforeImages)
	u.execContext.TxCtx.RoundImages.AppendAfterImages(afterImages)

	return res, nil
}

// isAstStmtValid reports whether a parsed UPDATE statement with a join
// (a non-nil right table reference) is available on the parse context.
func (u *updateJoinExecutor) isAstStmtValid() bool {
	if u.parserCtx == nil || u.parserCtx.UpdateStmt == nil {
		return false
	}
	return u.parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil
}

// beforeImage queries the pre-update snapshot of every table involved in the
// join and registers a lock key per image on the transaction context.
// Returns nil images when the statement is not a parseable update-join.
// Tables none of whose columns are updated are skipped.
func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) {
	if !u.isAstStmtValid() {
		return nil, nil
	}

	var recordImages []*types.RecordImage

	for tbName, tableAliases := range u.tableAliasesMap {
		metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tbName)
		if err != nil {
			return nil, err
		}
		selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, tableAliases, u.execContext.NamedValues)
		if err != nil {
			return nil, err
		}
		// Empty SQL means no updated column targets this table.
		if selectSQL == "" {
			log.Debugf("Skip unused table [{%s}] when build select sql by update sourceQuery", tbName)
			continue
		}

		var image *types.RecordImage
		rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
		if err == nil {
			image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
		}
		if rowsi != nil {
			// BUGFIX: was `rows.Close()` on an undefined variable (the handle
			// opened above is `rowsi`; compare afterImage, which is correct).
			if rowerr := rowsi.Close(); rowerr != nil {
				log.Errorf("rows close fail, err:%v", rowerr)
				return nil, rowerr
			}
		}
		if err != nil {
			// If one fail, all fails
			return nil, err
		}

		lockKey := u.buildLockKey(image, *metaData)
		u.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
		image.SQLType = u.parserCtx.SQLType

		recordImages = append(recordImages, image)
	}

	return recordImages, nil
}

// afterImage re-queries each table captured in beforeImages after the update
// ran, locating rows by the primary-key values recorded in the before image.
// Returns nil images when the statement is not a parseable update-join, and
// an error when beforeImages is empty (an after image without a matching
// before image cannot be rolled back).
func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
	if !u.isAstStmtValid() {
		return nil, nil
	}

	if len(beforeImages) == 0 {
		return nil, errors.New("empty beforeImages")
	}

	var recordImages []*types.RecordImage
	for _, beforeImage := range beforeImages {
		metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName)
		if err != nil {
			return nil, err
		}

		selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData, u.tableAliasesMap[beforeImage.TableName])
		if err != nil {
			return nil, err
		}

		var image *types.RecordImage
		rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
		if err == nil {
			image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
		}
		// Always close the rows handle before inspecting err: the handle may
		// be open even when buildRecordImages failed.
		if rowsi != nil {
			if rowerr := rowsi.Close(); rowerr != nil {
				log.Errorf("rows close fail, err:%v", rowerr)
				return nil, rowerr
			}
		}
		if err != nil {
			// If one fail, all fails
			return nil, err
		}

		image.SQLType = u.parserCtx.SQLType
		recordImages = append(recordImages, image)
	}

	return recordImages, nil
}

// buildAfterImageSQL build the SQL to query before image data
func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) {
updateStmt := u.parserCtx.UpdateStmt
fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List)
if err != nil {
return "", nil, err
}
if len(fields) == 0 {
return "", nil, err
}

selStmt := ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: updateStmt.TableRefs,
Where: updateStmt.Where,
Fields: &ast.FieldList{Fields: fields},
OrderBy: updateStmt.Order,
Limit: updateStmt.Limit,
TableHints: updateStmt.TableHints,
// maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
GroupBy: &ast.GroupByClause{
Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableAliases, tableMeta.GetPrimaryKeyOnlyName(), fields),
},
LockInfo: &ast.SelectLockInfo{
LockType: ast.SelectLockForUpdate,
},
}

b := bytes.NewByteBuffer([]byte{})
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
sql := string(b.Bytes())
log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

return sql, u.buildSelectArgs(&selStmt, args), nil
}

// buildAfterImageSQL builds the SELECT statement used to re-query the rows of
// one table after the update ran. The returned args are the primary-key
// values taken from the before image (via buildPKParams), so the same rows
// are located. Returns an empty SQL string when the before image had no rows.
//
// NOTE(review): the restored SQL keeps the original WHERE clause while the
// bound args are PK params — presumably rowsPrepare/driver binds them to the
// PK placeholders; verify against callers.
func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta, tableAliases string) (string, []driver.NamedValue, error) {
	if len(beforeImage.Rows) == 0 {
		return "", nil, nil
	}

	fields, err := u.buildSelectFields(ctx, meta, tableAliases, u.parserCtx.UpdateStmt.List)
	if err != nil {
		return "", nil, err
	}
	// err is nil here; empty fields means the table is not touched.
	if len(fields) == 0 {
		return "", nil, err
	}

	updateStmt := u.parserCtx.UpdateStmt
	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmt.TableRefs,
		Where:          updateStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		OrderBy:        updateStmt.Order,
		Limit:          updateStmt.Limit,
		TableHints:     updateStmt.TableHints,
		// maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
		GroupBy: &ast.GroupByClause{
			Items: u.buildGroupByClause(ctx, meta.TableName, tableAliases, meta.GetPrimaryKeyOnlyName(), fields),
		},
	}

	b := bytes.NewByteBuffer([]byte{})
	_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

	return sql, u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()), nil
}

// parseTableName walks a (possibly left-nested) join tree and returns a map
// of table name -> alias ("" when no alias was given).
//
// NOTE(review): the type assertions assume every join leaf is an
// *ast.TableSource wrapping an *ast.TableName; a subquery or other source
// kind would panic here — confirm the callers guarantee plain tables.
func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) map[string]string {
	tableNames := make(map[string]string, 0)
	// Joins nest on the left: recurse until the leftmost leaf is a table.
	if item, ok := joinMate.Left.(*ast.Join); ok {
		tableNames = u.parseTableName(item)
	} else {
		leftTableSource := joinMate.Left.(*ast.TableSource)
		leftName := leftTableSource.Source.(*ast.TableName)
		tableNames[leftName.Name.O] = leftTableSource.AsName.O
	}

	rightTableSource := joinMate.Right.(*ast.TableSource)
	rightName := rightTableSource.Source.(*ast.TableName)
	tableNames[rightName.Name.O] = rightTableSource.AsName.O
	return tableNames
}

// buildGroupByClause builds the GROUP BY items used to de-duplicate rows in
// the select-join image SQL. By default it groups by the primary-key columns
// of tableName (using the alias when present). On MySQL versions below
// LowerSupportGroupByPksVersion it first checks @@SQL_MODE: when
// ONLY_FULL_GROUP_BY is set (or the mode cannot be determined), every
// selected column is grouped instead.
func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, tableAliases string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem {
	var groupByPks = true
	if tableAliases != "" {
		tableName = tableAliases
	}
	// only pks group by is valid when db version >= 5.7.5
	if u.isLowerSupportGroupByPksVersion {
		if u.sqlMode == "" {
			rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, "SELECT @@SQL_MODE", nil)
			defer func() {
				if rowsi != nil {
					if rowerr := rowsi.Close(); rowerr != nil {
						log.Errorf("rows close fail, err:%v", rowerr)
					}
				}
			}()
			if err != nil {
				groupByPks = false
				log.Warnf("determine group by pks or all columns error:%s", err)
			} else {
				mode := make([]driver.Value, 1)
				// BUGFIX: cache SQL_MODE on a successful read. The original
				// only assigned u.sqlMode when Next returned a non-EOF error,
				// so the mode was never captured and the ONLY_FULL_GROUP_BY
				// check below could not fire.
				if err = rowsi.Next(mode); err == nil && len(mode) == 1 {
					// MySQL drivers usually surface string columns as []byte;
					// reflect.ValueOf(...).String() on []byte yields a junk
					// "<[]uint8 Value>" string, so convert explicitly.
					switch v := mode[0].(type) {
					case string:
						u.sqlMode = v
					case []byte:
						u.sqlMode = string(v)
					default:
						u.sqlMode = reflect.ValueOf(mode[0]).String()
					}
				} else if err != nil && err != io.EOF {
					log.Warnf("determine group by pks or all columns error:%s", err)
				}
			}
		}

		if strings.Contains(u.sqlMode, "ONLY_FULL_GROUP_BY") {
			groupByPks = false
		}
	}

	groupByColumns := make([]*ast.ByItem, 0)
	if groupByPks {
		// Group by primary keys only (qualified with the table/alias name).
		for _, column := range pkColumns {
			groupByColumns = append(groupByColumns, &ast.ByItem{
				Expr: &ast.ColumnNameExpr{
					Name: &ast.ColumnName{
						Table: model.CIStr{
							O: tableName,
							L: strings.ToLower(tableName),
						},
						Name: model.CIStr{
							O: column,
							L: strings.ToLower(column),
						},
					},
				},
			})
		}
	} else {
		// Fall back to grouping by every selected column.
		for _, column := range allSelectColumns {
			groupByColumns = append(groupByColumns, &ast.ByItem{
				Expr: column.Expr,
			})
		}
	}
	return groupByColumns
}

+ 231
- 0
pkg/datasource/sql/exec/at/update_join_executor_test.go View File

@@ -0,0 +1,231 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"context"
"database/sql/driver"
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
"testing"

"github.com/stretchr/testify/assert"

"seata.apache.org/seata-go/pkg/datasource/sql/exec"
"seata.apache.org/seata-go/pkg/datasource/sql/parser"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
_ "seata.apache.org/seata-go/pkg/util/log"
)

func TestBuildSelectSQLByUpdateJoin(t *testing.T) {
MetaDataMap := map[string]*types.TableMeta{
"table1": {
TableName: "table1",
Indexs: map[string]types.IndexMeta{
"id": {
IType: types.IndexTypePrimaryKey,
Columns: []types.ColumnMeta{
{ColumnName: "id"},
},
},
},
Columns: map[string]types.ColumnMeta{
"id": {
ColumnDef: nil,
ColumnName: "id",
},
"name": {
ColumnDef: nil,
ColumnName: "name",
},
"age": {
ColumnDef: nil,
ColumnName: "age",
},
},
ColumnNames: []string{"id", "name", "age"},
},
"table2": {
TableName: "table2",
Indexs: map[string]types.IndexMeta{
"id": {
IType: types.IndexTypePrimaryKey,
Columns: []types.ColumnMeta{
{ColumnName: "id"},
},
},
},
Columns: map[string]types.ColumnMeta{
"id": {
ColumnDef: nil,
ColumnName: "id",
},
"name": {
ColumnDef: nil,
ColumnName: "name",
},
"age": {
ColumnDef: nil,
ColumnName: "age",
},
"kk": {
ColumnDef: nil,
ColumnName: "kk",
},
"addr": {
ColumnDef: nil,
ColumnName: "addr",
},
},
ColumnNames: []string{"id", "name", "age", "kk", "addr"},
},
"table3": {
TableName: "table3",
Indexs: map[string]types.IndexMeta{
"id": {
IType: types.IndexTypePrimaryKey,
Columns: []types.ColumnMeta{
{ColumnName: "id"},
},
},
},
Columns: map[string]types.ColumnMeta{
"id": {
ColumnDef: nil,
ColumnName: "id",
},
"age": {
ColumnDef: nil,
ColumnName: "age",
},
},
ColumnNames: []string{"id", "age"},
},
"table4": {
TableName: "table4",
Indexs: map[string]types.IndexMeta{
"id": {
IType: types.IndexTypePrimaryKey,
Columns: []types.ColumnMeta{
{ColumnName: "id"},
},
},
},
Columns: map[string]types.ColumnMeta{
"id": {
ColumnDef: nil,
ColumnName: "id",
},
"age": {
ColumnDef: nil,
ColumnName: "age",
},
},
ColumnNames: []string{"id", "age"},
},
}

undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})

tests := []struct {
name string
sourceQuery string
sourceQueryArgs []driver.Value
expectQuery map[string]string
expectQueryArgs []driver.Value
}{
{
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id and t1.age=? set t1.name = 'WILL',t2.name = ?",
sourceQueryArgs: []driver.Value{18, "Jack"},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t1.name,t1.id FOR UPDATE",
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t2.name,t2.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{18},
},
{
sourceQuery: "update table1 AS t1 inner join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?",
sourceQueryArgs: []driver.Value{1},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE",
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{1},
},
{
sourceQuery: "update table1 AS t1 right join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?",
sourceQueryArgs: []driver.Value{1},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE",
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{1},
},
{
sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?",
sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{1, "Jack", 18, 28},
},
{
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)",
sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{1, 18, 28},
},
{
sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?",
sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE",
"table2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE",
},
expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2},
},
{
sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id inner join table3 t3 on t3.id = t2.id right join table4 t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30",
sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10},
expectQuery: map[string]string{
"table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE",
"table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE",
},
expectQueryArgs: []driver.Value{1, 10},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := parser.DoParser(tt.sourceQuery)
assert.Nil(t, err)
executor := NewUpdateJoinExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{})
tableNames := executor.(*updateJoinExecutor).parseTableName(c.UpdateStmt.TableRefs.TableRefs)
for tbName, tableAliases := range tableNames {
query, args, err := executor.(*updateJoinExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap[tbName], tableAliases, util.ValueToNamedValue(tt.sourceQueryArgs))
assert.Nil(t, err)
if query == "" {
continue
}
assert.Equal(t, tt.expectQuery[tbName], query)
assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args))
}
})
}
}

+ 7
- 4
pkg/datasource/sql/types/image.go View File

@@ -18,6 +18,7 @@
package types package types


import ( import (
"database/sql/driver"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"reflect" "reflect"
@@ -117,14 +118,16 @@ type RecordImage struct {
// Rows data row // Rows data row
Rows []RowImage `json:"rows"` Rows []RowImage `json:"rows"`
// TableMeta table information schema // TableMeta table information schema
TableMeta *TableMeta `json:"-"`
TableMeta *TableMeta `json:"-"`
PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"`
} }


func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage { func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage {
return &RecordImage{ return &RecordImage{
TableName: tableMeta.TableName,
TableMeta: tableMeta,
SQLType: sqlType,
TableName: tableMeta.TableName,
TableMeta: tableMeta,
SQLType: sqlType,
PrimaryKeyMap: make(map[string][]driver.Value),
} }
} }




+ 4
- 42
pkg/datasource/sql/undo/builder/basic_undo_log_builder.go View File

@@ -22,12 +22,12 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"io"
"strings"

"github.com/arana-db/parser/ast" "github.com/arana-db/parser/ast"
"github.com/arana-db/parser/test_driver" "github.com/arana-db/parser/test_driver"
gxsort "github.com/dubbogo/gost/sort" gxsort "github.com/dubbogo/gost/sort"
"io"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"strings"


"seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/types"
) )
@@ -276,43 +276,5 @@ func (b *BasicUndoLogBuilder) buildLockKey(rows driver.Rows, meta types.TableMet


// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" // the string as local key. the local key example(multi pk): "t_user:1_a,2_b"
func (b *BasicUndoLogBuilder) buildLockKey2(records *types.RecordImage, meta types.TableMeta) string { func (b *BasicUndoLogBuilder) buildLockKey2(records *types.RecordImage, meta types.TableMeta) string {
var lockKeys bytes.Buffer
lockKeys.WriteString(meta.TableName)
lockKeys.WriteString(":")

keys := meta.GetPrimaryKeyOnlyName()
keyIndexMap := make(map[string]int, len(keys))

for idx, columnName := range keys {
keyIndexMap[columnName] = idx
}

primaryKeyRows := make([][]interface{}, len(records.Rows))

for i, row := range records.Rows {
primaryKeyValues := make([]interface{}, len(keys))
for _, column := range row.Columns {
if idx, exist := keyIndexMap[column.ColumnName]; exist {
primaryKeyValues[idx] = column.Value
}
}
primaryKeyRows[i] = primaryKeyValues
}

for i, primaryKeyValues := range primaryKeyRows {
if i > 0 {
lockKeys.WriteString(",")
}
for j, pkVal := range primaryKeyValues {
if j > 0 {
lockKeys.WriteString("_")
}
if pkVal == nil {
continue
}
lockKeys.WriteString(fmt.Sprintf("%v", pkVal))
}
}

return lockKeys.String()
return util.BuildLockKey(records, meta)
} }

+ 19
- 0
pkg/datasource/sql/undo/builder/basic_undo_log_builder_test.go View File

@@ -69,6 +69,7 @@ func TestBuildLockKey(t *testing.T) {
} }


columnsTwoPk := []types.ColumnMeta{columnID, columnUserId} columnsTwoPk := []types.ColumnMeta{columnID, columnUserId}
columnsThreePk := []types.ColumnMeta{columnID, columnUserId, columnAge}
columnsMixPk := []types.ColumnMeta{columnName, columnAge} columnsMixPk := []types.ColumnMeta{columnName, columnAge}


getColumnImage := func(columnName string, value interface{}) types.ColumnImage { getColumnImage := func(columnName string, value interface{}) types.ColumnImage {
@@ -98,6 +99,24 @@ func TestBuildLockKey(t *testing.T) {
}, },
"test_name:1_one,2_two", "test_name:1_one,2_two",
}, },
{
"Three Primary Keys",
types.TableMeta{
TableName: "test2_name",
Indexs: map[string]types.IndexMeta{
"PRIMARY_KEY": {IType: types.IndexTypePrimaryKey, Columns: columnsThreePk},
},
},
types.RecordImage{
TableName: "test2_name",
Rows: []types.RowImage{
{[]types.ColumnImage{getColumnImage("id", 1), getColumnImage("userId", "one"), getColumnImage("age", "11")}},
{[]types.ColumnImage{getColumnImage("id", 2), getColumnImage("userId", "two"), getColumnImage("age", "22")}},
{[]types.ColumnImage{getColumnImage("id", 3), getColumnImage("userId", "three"), getColumnImage("age", "33")}},
},
},
"test2_name:1_one_11,2_two_22,3_three_33",
},
{ {
name: "Single Primary Key", name: "Single Primary Key",
metaData: types.TableMeta{ metaData: types.TableMeta{


+ 154
- 66
pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go View File

@@ -97,68 +97,108 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a
if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil {
return "", nil, err return "", nil, err
} }
var selectArgs []driver.Value
u.BeforeImageSqlPrimaryKeys = make(map[string]bool, len(metaData.Indexs))
pkIndexMap := u.getPkIndex(insertStmt, metaData) pkIndexMap := u.getPkIndex(insertStmt, metaData)
var pkIndexArray []int var pkIndexArray []int
for _, val := range pkIndexMap { for _, val := range pkIndexMap {
tmpVal := val
pkIndexArray = append(pkIndexArray, tmpVal)
pkIndexArray = append(pkIndexArray, val)
} }
insertRows, err := getInsertRows(insertStmt, pkIndexArray) insertRows, err := getInsertRows(insertStmt, pkIndexArray)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
insertNum := len(insertRows)
paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) paramMap, err := u.buildImageParameters(insertStmt, args, insertRows)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

sql := strings.Builder{}
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
if len(paramMap) == 0 || len(metaData.Indexs) == 0 {
return "", nil, nil
}
hasPK := false
for _, index := range metaData.Indexs {
if strings.EqualFold("PRIMARY", index.Name) {
allPKColumnsHaveValue := true
for _, col := range index.Columns {
if params, ok := paramMap[col.ColumnName]; !ok || len(params) == 0 || params[0] == nil {
allPKColumnsHaveValue = false
break
}
}
hasPK = allPKColumnsHaveValue
break
}
}
if !hasPK {
hasValidUniqueIndex := false
for _, index := range metaData.Indexs {
if !index.NonUnique && !strings.EqualFold("PRIMARY", index.Name) {
if _, _, valid := validateIndexPrefix(index, paramMap, 0); valid {
hasValidUniqueIndex = true
break
}
}
}
if !hasValidUniqueIndex {
return "", nil, nil
}
}
var sql strings.Builder
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
var selectArgs []driver.Value
isContainWhere := false isContainWhere := false
for i := 0; i < insertNum; i++ {
finalI := i
paramAppenderTempList := make([]driver.Value, 0)
hasConditions := false
for i := 0; i < len(insertRows); i++ {
var rowConditions []string
var rowArgs []driver.Value
usedParams := make(map[string]bool)

// First try unique indexes
for _, index := range metaData.Indexs { for _, index := range metaData.Indexs {
//unique index
if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false {
if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) {
continue continue
} }
columnIsNull := true
uniqueList := make([]string, 0)
for _, columnMeta := range index.Columns {
columnName := strings.ToLower(columnMeta.ColumnName)
imageParameters, ok := paramMap[columnName]
if !ok && columnMeta.ColumnDef != nil {
if strings.EqualFold("PRIMARY", index.Name) {
u.BeforeImageSqlPrimaryKeys[columnName] = true
}
uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ")
columnIsNull = false
continue
}
if strings.EqualFold("PRIMARY", index.Name) {
u.BeforeImageSqlPrimaryKeys[columnName] = true
if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid {
rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")")
rowArgs = append(rowArgs, args...)
hasConditions = true
for _, colMeta := range index.Columns {
usedParams[colMeta.ColumnName] = true
} }
columnIsNull = false
uniqueList = append(uniqueList, columnName+" = ? ")
paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI])
} }
}


if !columnIsNull {
if isContainWhere {
sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ")
} else {
sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ")
isContainWhere = true
// Then try primary key
for _, index := range metaData.Indexs {
if !strings.EqualFold("PRIMARY", index.Name) {
continue
}
if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid {
rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")")
rowArgs = append(rowArgs, args...)
hasConditions = true
for _, colMeta := range index.Columns {
usedParams[colMeta.ColumnName] = true
} }
} }
} }
selectArgs = append(selectArgs, paramAppenderTempList...)

if len(rowConditions) > 0 {
if !isContainWhere {
sql.WriteString("WHERE ")
isContainWhere = true
} else {
sql.WriteString(" OR ")
}
sql.WriteString(strings.Join(rowConditions, " OR ") + " ")
selectArgs = append(selectArgs, rowArgs...)
}
}
if !hasConditions {
return "", nil, nil
} }
log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String())
return sql.String(), selectArgs, nil
sqlStr := sql.String()
log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr)
return sqlStr, selectArgs, nil
} }


func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
@@ -168,14 +208,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e
log.Errorf("build prepare stmt: %+v", err) log.Errorf("build prepare stmt: %+v", err)
return nil, err return nil, err
} }

defer stmt.Close()
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
metaData := execCtx.MetaDataMap[tableName]
rows, err := stmt.Query(selectArgs) rows, err := stmt.Query(selectArgs)
if err != nil { if err != nil {
log.Errorf("stmt query: %+v", err)
return nil, err return nil, err
} }
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
metaData := execCtx.MetaDataMap[tableName]
defer rows.Close()
image, err := u.buildRecordImages(rows, &metaData) image, err := u.buildRecordImages(rows, &metaData)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -185,11 +225,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e


func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) { func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) {
selectSQL, selectArgs := u.BeforeSelectSql, u.Args selectSQL, selectArgs := u.BeforeSelectSql, u.Args

var beforeImage *types.RecordImage var beforeImage *types.RecordImage
if len(beforeImages) > 0 { if len(beforeImages) > 0 {
beforeImage = beforeImages[0] beforeImage = beforeImages[0]
} }
if beforeImage == nil || len(beforeImage.Rows) == 0 {
return selectSQL, selectArgs
}
primaryValueMap := make(map[string][]interface{}) primaryValueMap := make(map[string][]interface{})
for _, row := range beforeImage.Rows { for _, row := range beforeImage.Rows {
for _, col := range row.Columns { for _, col := range row.Columns {
@@ -198,25 +240,46 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
} }
} }
} }

var afterImageSql strings.Builder var afterImageSql strings.Builder
var primaryValues []driver.Value
afterImageSql.WriteString(selectSQL) afterImageSql.WriteString(selectSQL)
for i := 0; i < len(beforeImage.Rows); i++ {
wherePrimaryList := make([]string, 0)
for name, value := range primaryValueMap {
if !u.BeforeImageSqlPrimaryKeys[name] {
wherePrimaryList = append(wherePrimaryList, name+" = ? ")
primaryValues = append(primaryValues, value[i])
if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) {
return selectSQL, selectArgs
}
var primaryValues []driver.Value
usedPrimaryKeys := make(map[string]bool)
for name := range primaryValueMap {
if !u.BeforeImageSqlPrimaryKeys[name] {
usedPrimaryKeys[name] = true
for i := 0; i < len(beforeImage.Rows); i++ {
if value := primaryValueMap[name][i]; value != nil {
if dv, ok := value.(driver.Value); ok {
primaryValues = append(primaryValues, dv)
} else {
primaryValues = append(primaryValues, value)
}
}
} }
} }
if len(wherePrimaryList) != 0 {
afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ")
}
if len(primaryValues) > 0 {
afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ")
}
finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues))
copy(finalArgs, selectArgs)
copy(finalArgs[len(selectArgs):], primaryValues)
sqlStr := afterImageSql.String()
log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr)
return sqlStr, finalArgs
}

func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string {
var conditions []string
for name := range primaryValueMap {
if !usedPrimaryKeys[name] {
conditions = append(conditions, name+" = ? ")
} }
} }
selectArgs = append(selectArgs, primaryValues...)
log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String())
return afterImageSql.String(), selectArgs
return conditions
} }


func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error {
@@ -243,11 +306,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e


// build sql params // build sql params
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) {
var (
parameterMap = make(map[string][]driver.Value)
)
parameterMap := make(map[string][]driver.Value)
insertColumns := getInsertColumns(insert) insertColumns := getInsertColumns(insert)
var placeHolderIndex = 0
placeHolderIndex := 0

for _, row := range insertRows { for _, row := range insertRows {
if len(row) != len(insertColumns) { if len(row) != len(insertColumns) {
log.Errorf("insert row's column size not equal to insert column size") log.Errorf("insert row's column size not equal to insert column size")
@@ -256,13 +318,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.
for i, col := range insertColumns { for i, col := range insertColumns {
columnName := strings.ToLower(executor.DelEscape(col, types.DBTypeMySQL)) columnName := strings.ToLower(executor.DelEscape(col, types.DBTypeMySQL))
val := row[i] val := row[i]
rStr, ok := val.(string)
if ok && strings.EqualFold(rStr, SqlPlaceholder) {
objects := args[placeHolderIndex]
parameterMap[columnName] = append(parameterMap[col], objects)
if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) {
if placeHolderIndex >= len(args) {
return nil, fmt.Errorf("not enough parameters for placeholders")
}
parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex])
placeHolderIndex++ placeHolderIndex++
} else { } else {
parameterMap[columnName] = append(parameterMap[col], val)
parameterMap[columnName] = append(parameterMap[columnName], val)
} }
} }
} }
@@ -296,3 +359,28 @@ func isIndexValueNotNull(indexMeta types.IndexMeta, imageParameterMap map[string
} }
return true return true
} }

func validateIndexPrefix(index types.IndexMeta, paramMap map[string][]driver.Value, rowIndex int) ([]string, []driver.Value, bool) {
var indexConditions []string
var indexArgs []driver.Value
if len(index.Columns) > 1 {
for _, colMeta := range index.Columns {
params, ok := paramMap[colMeta.ColumnName]
if !ok || len(params) <= rowIndex || params[rowIndex] == nil {
return nil, nil, false
}
}
}
for _, colMeta := range index.Columns {
columnName := colMeta.ColumnName
params, ok := paramMap[columnName]
if ok && len(params) > rowIndex && params[rowIndex] != nil {
indexConditions = append(indexConditions, columnName+" = ? ")
indexArgs = append(indexArgs, params[rowIndex])
}
}
if len(indexConditions) != len(index.Columns) {
return nil, nil, false
}
return indexConditions, indexArgs, true
}

+ 63
- 0
pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go View File

@@ -143,6 +143,69 @@ func TestInsertOnDuplicateBuildBeforeImageSQL(t *testing.T) {
expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ",
expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)}, expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)},
}, },
// Test case for null unique index
{
execCtx: &types.ExecContext{
Query: "insert into t_user(id, name, age) values(?, ?, ?) on duplicate key update age = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1},
},
sourceQueryArgs: []driver.Value{1, nil, 2, 5},
expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ",
expectQueryArgs1: []driver.Value{1},
},
// Test case for null primary key
{
execCtx: &types.ExecContext{
Query: "insert into t_user(id, age) values(?, ?) on duplicate key update age = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1},
},
sourceQueryArgs: []driver.Value{nil, 2, 5},
expectQuery1: "SELECT * FROM t_user WHERE (age = ? )",
expectQueryArgs1: []driver.Value{2},
},
// Test case for null unique index with no primary key
{
execCtx: &types.ExecContext{
Query: "insert into t_user(name, age) values(?, ?) on duplicate key update age = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2},
},
sourceQueryArgs: []driver.Value{nil, 2, 5},
expectQuery1: "",
expectQueryArgs1: nil,
},
// Test case for composite index with all columns
{
name: "composite_index_full",
execCtx: &types.ExecContext{
Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1},
},
sourceQueryArgs: []driver.Value{1, "Jack", 25, "other"},
expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ",
expectQueryArgs1: []driver.Value{"Jack", 25, 1},
},
// Test case for composite index with null value
{
name: "composite_index_with_null",
execCtx: &types.ExecContext{
Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1},
},
sourceQueryArgs: []driver.Value{1, "Jack", nil, "other"},
expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ",
expectQueryArgs1: []driver.Value{1},
},
// Test case for composite index with leftmost prefix only
{
name: "composite_index_leftmost_prefix",
execCtx: &types.ExecContext{
Query: "insert into t_user(id, name) values(?,?) on duplicate key update other = ?",
MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1},
},
sourceQueryArgs: []driver.Value{1, "Jack", "other"},
expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ",
expectQueryArgs1: []driver.Value{1},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {


+ 75
- 0
pkg/datasource/sql/util/lockkey.go View File

@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package util

import (
"fmt"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"strings"
)

// BuildLockKey renders the global row-lock key for a record image in the
// form "tableName:pk1_pk2,pk1_pk2,..." — rows separated by ",", the values
// of a composite primary key joined by "_" (e.g. "t_user:1_a,2_b").
// A nil primary-key value contributes an empty segment; an image with no
// rows yields just "tableName:".
func BuildLockKey(records *types.RecordImage, meta types.TableMeta) string {
	var b strings.Builder
	b.WriteString(meta.TableName)
	b.WriteString(":")

	// Position of each primary-key column within the composed key.
	pkNames := meta.GetPrimaryKeyOnlyName()
	pkOrder := make(map[string]int, len(pkNames))
	for pos, name := range pkNames {
		pkOrder[name] = pos
	}

	if len(records.Rows) == 0 {
		return b.String()
	}

	// Resolve pk-position -> column-position once, from the first row.
	// NOTE(review): assumes every row shares the first row's column layout —
	// TODO confirm record images are always uniform.
	type pkCol struct {
		pkPos  int
		colPos int
	}
	mapping := make([]pkCol, 0, len(pkNames))
	for colPos, col := range records.Rows[0].Columns {
		if pkPos, ok := pkOrder[col.ColumnName]; ok {
			mapping = append(mapping, pkCol{pkPos: pkPos, colPos: colPos})
		}
	}

	for rowIdx, row := range records.Rows {
		if rowIdx > 0 {
			b.WriteString(",")
		}
		pkVals := make([]interface{}, len(pkNames))
		for _, m := range mapping {
			// Guard against a short row; the pk slot then stays nil.
			if m.colPos < len(row.Columns) {
				pkVals[m.pkPos] = row.Columns[m.colPos].Value
			}
		}
		for i, v := range pkVals {
			if i > 0 {
				b.WriteString("_")
			}
			if v == nil {
				continue
			}
			b.WriteString(fmt.Sprintf("%v", v))
		}
	}
	return b.String()
}

Loading…
Cancel
Save