|
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- package at
-
- import (
- "context"
- "database/sql/driver"
- "fmt"
- "strings"
-
- "github.com/arana-db/parser/ast"
-
- "seata.apache.org/seata-go/pkg/datasource/sql/datasource"
- "seata.apache.org/seata-go/pkg/datasource/sql/exec"
- "seata.apache.org/seata-go/pkg/datasource/sql/types"
- "seata.apache.org/seata-go/pkg/datasource/sql/util"
- "seata.apache.org/seata-go/pkg/util/log"
- )
-
// sqlPlaceholder is the marker that stands for a bound parameter in a row of
// parsed INSERT values.
const sqlPlaceholder = "?"
-
- // insertExecutor execute insert SQL
- type insertExecutor struct {
- baseExecutor
- parserCtx *types.ParseContext
- execContext *types.ExecContext
- incrementStep int
- // businesSQLResult after insert sql
- businesSQLResult types.ExecResult
- }
-
- // NewInsertExecutor get insert executor
- func NewInsertExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
- return &insertExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}
- }
-
- func (i *insertExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
- i.beforeHooks(ctx, i.execContext)
- defer func() {
- i.afterHooks(ctx, i.execContext)
- }()
-
- beforeImage, err := i.beforeImage(ctx)
- if err != nil {
- return nil, err
- }
-
- res, err := f(ctx, i.execContext.Query, i.execContext.NamedValues)
- if err != nil {
- return nil, err
- }
-
- if i.businesSQLResult == nil {
- i.businesSQLResult = res
- }
-
- afterImage, err := i.afterImage(ctx)
- if err != nil {
- return nil, err
- }
-
- i.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
- i.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage)
- return res, nil
- }
-
- // beforeImage build before image
- func (i *insertExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
- tableName, _ := i.parserCtx.GetTableName()
- metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
- if err != nil {
- return nil, err
- }
- return types.NewEmptyRecordImage(metaData, types.SQLTypeInsert), nil
- }
-
- // afterImage build after image
- func (i *insertExecutor) afterImage(ctx context.Context) (*types.RecordImage, error) {
- if !i.isAstStmtValid() {
- return nil, nil
- }
-
- tableName, _ := i.parserCtx.GetTableName()
- metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
- if err != nil {
- return nil, err
- }
- selectSQL, selectArgs, err := i.buildAfterImageSQL(ctx)
- if err != nil {
- return nil, err
- }
-
- var rowsi driver.Rows
- queryerCtx, ok := i.execContext.Conn.(driver.QueryerContext)
- var queryer driver.Queryer
- if !ok {
- queryer, ok = i.execContext.Conn.(driver.Queryer)
- }
- if ok {
- rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
- defer func() {
- if rowsi != nil {
- rowsi.Close()
- }
- }()
- if err != nil {
- log.Errorf("ctx driver query: %+v", err)
- return nil, err
- }
- } else {
- log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
- return nil, fmt.Errorf("invalid conn")
- }
-
- image, err := i.buildRecordImages(rowsi, metaData, types.SQLTypeInsert)
- if err != nil {
- return nil, err
- }
-
- lockKey := i.buildLockKey(image, *metaData)
- i.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
- return image, nil
- }
-
- // buildAfterImageSQL build select sql from insert sql
- func (i *insertExecutor) buildAfterImageSQL(ctx context.Context) (string, []driver.NamedValue, error) {
- // get all pk value
- tableName, _ := i.parserCtx.GetTableName()
-
- meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
- if err != nil {
- return "", nil, err
- }
- pkValuesMap, err := i.getPkValues(ctx, i.execContext, i.parserCtx, *meta)
- if err != nil {
- return "", nil, err
- }
-
- pkColumnNameList := meta.GetPrimaryKeyOnlyName()
- if len(pkColumnNameList) == 0 {
- return "", nil, fmt.Errorf("Pk columnName size is zero")
- }
-
- dataTypeMap, err := meta.GetPrimaryKeyTypeStrMap()
- if err != nil {
- return "", nil, err
- }
- if len(dataTypeMap) != len(pkColumnNameList) {
- return "", nil, fmt.Errorf("PK columnName size don't equal PK DataType size")
- }
- var pkRowImages []types.RowImage
-
- rowSize := len(pkValuesMap[pkColumnNameList[0]])
- for i := 0; i < rowSize; i++ {
- for _, name := range pkColumnNameList {
- tmpKey := name
- tmpArray := pkValuesMap[tmpKey]
- pkRowImages = append(pkRowImages, types.RowImage{
- Columns: []types.ColumnImage{{
- KeyType: types.IndexTypePrimaryKey,
- ColumnName: tmpKey,
- ColumnType: types.MySQLStrToJavaType(dataTypeMap[tmpKey]),
- Value: tmpArray[i],
- }},
- })
- }
- }
- // build check sql
- sb := strings.Builder{}
- sb.WriteString("SELECT * FROM " + tableName)
- whereSQL := i.buildWhereConditionByPKs(pkColumnNameList, len(pkValuesMap[pkColumnNameList[0]]), "mysql", maxInSize)
- sb.WriteString(" WHERE " + whereSQL + " ")
- return sb.String(), i.buildPKParams(pkRowImages, pkColumnNameList), nil
- }
-
- func (i *insertExecutor) getPkValues(ctx context.Context, execCtx *types.ExecContext, parseCtx *types.ParseContext, meta types.TableMeta) (map[string][]interface{}, error) {
- pkColumnNameList := meta.GetPrimaryKeyOnlyName()
- pkValuesMap := make(map[string][]interface{})
- var err error
- // when there is only one pk in the table
- if len(pkColumnNameList) == 1 {
- if i.containsPK(meta, parseCtx) {
- // the insert sql contain pk value
- pkValuesMap, err = i.getPkValuesByColumn(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- } else if containsColumns(parseCtx) {
- // the insert table pk auto generated
- pkValuesMap, err = i.getPkValuesByAuto(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- } else {
- pkValuesMap, err = i.getPkValuesByColumn(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- }
- } else {
- // when there is multiple pk in the table
- // 1,all pk columns are filled value.
- // 2,the auto increment pk column value is null, and other pk value are not null.
- pkValuesMap, err = i.getPkValuesByColumn(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- for _, columnName := range pkColumnNameList {
- if _, ok := pkValuesMap[columnName]; !ok {
- curPkValuesMap, err := i.getPkValuesByAuto(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- pkValuesMapMerge(&pkValuesMap, curPkValuesMap)
- }
- }
- }
- return pkValuesMap, nil
- }
-
- // containsPK the columns contains table meta pk
- func (i *insertExecutor) containsPK(meta types.TableMeta, parseCtx *types.ParseContext) bool {
- pkColumnNameList := meta.GetPrimaryKeyOnlyName()
- if len(pkColumnNameList) == 0 {
- return false
- }
- if parseCtx == nil || parseCtx.InsertStmt == nil || parseCtx.InsertStmt.Columns == nil {
- return false
- }
- if len(parseCtx.InsertStmt.Columns) == 0 {
- return false
- }
-
- matchCounter := 0
- for _, column := range parseCtx.InsertStmt.Columns {
- for _, pkName := range pkColumnNameList {
- if strings.EqualFold(pkName, column.Name.O) ||
- strings.EqualFold(pkName, column.Name.L) {
- matchCounter++
- }
- }
- }
-
- return matchCounter == len(pkColumnNameList)
- }
-
- // containPK compare column name and primary key name
- func (i *insertExecutor) containPK(columnName string, meta types.TableMeta) bool {
- newColumnName := DelEscape(columnName, types.DBTypeMySQL)
- pkColumnNameList := meta.GetPrimaryKeyOnlyName()
- if len(pkColumnNameList) == 0 {
- return false
- }
- for _, name := range pkColumnNameList {
- if strings.EqualFold(name, newColumnName) {
- return true
- }
- }
- return false
- }
-
- // getPkIndex get pk index
- // return the key is pk column name and the value is index of the pk column
- func (i *insertExecutor) getPkIndex(InsertStmt *ast.InsertStmt, meta types.TableMeta) map[string]int {
- pkIndexMap := make(map[string]int)
- if InsertStmt == nil {
- return pkIndexMap
- }
- insertColumnsSize := len(InsertStmt.Columns)
- if insertColumnsSize == 0 {
- return pkIndexMap
- }
- if meta.ColumnNames == nil {
- return pkIndexMap
- }
- if len(meta.Columns) > 0 {
- for paramIdx := 0; paramIdx < insertColumnsSize; paramIdx++ {
- sqlColumnName := InsertStmt.Columns[paramIdx].Name.O
- if i.containPK(sqlColumnName, meta) {
- pkIndexMap[sqlColumnName] = paramIdx
- }
- }
- return pkIndexMap
- }
-
- pkIndex := -1
- allColumns := meta.Columns
- for _, columnMeta := range allColumns {
- tmpColumnMeta := columnMeta
- pkIndex++
- if i.containPK(tmpColumnMeta.ColumnName, meta) {
- pkIndexMap[DelEscape(tmpColumnMeta.ColumnName, types.DBTypeMySQL)] = pkIndex
- }
- }
-
- return pkIndexMap
- }
-
- // parsePkValuesFromStatement parse primary key value from statement.
- // return the primary key and values<key:primary key,value:primary key values></key:primary>
- func (i *insertExecutor) parsePkValuesFromStatement(insertStmt *ast.InsertStmt, meta types.TableMeta, nameValues []driver.NamedValue) (map[string][]interface{}, error) {
- if insertStmt == nil {
- return nil, nil
- }
- pkIndexMap := i.getPkIndex(insertStmt, meta)
- if pkIndexMap == nil || len(pkIndexMap) == 0 {
- return nil, fmt.Errorf("pkIndex is not found")
- }
- var pkIndexArray []int
- for _, val := range pkIndexMap {
- tmpVal := val
- pkIndexArray = append(pkIndexArray, tmpVal)
- }
-
- if insertStmt == nil || len(insertStmt.Lists) == 0 {
- return nil, fmt.Errorf("parCtx is nil, perhaps InsertStmt is empty")
- }
-
- pkValuesMap := make(map[string][]interface{})
-
- if nameValues != nil && len(nameValues) > 0 {
- // use prepared statements
- insertRows, err := getInsertRows(insertStmt, pkIndexArray)
- if err != nil {
- return nil, err
- }
- if insertRows == nil || len(insertRows) == 0 {
- return nil, err
- }
- totalPlaceholderNum := -1
- for _, row := range insertRows {
- if len(row) == 0 {
- continue
- }
- currentRowPlaceholderNum := -1
- for _, r := range row {
- rStr, ok := r.(string)
- if ok && strings.EqualFold(rStr, sqlPlaceholder) {
- totalPlaceholderNum += 1
- currentRowPlaceholderNum += 1
- }
- }
- var pkKey string
- var pkIndex int
- var pkValues []interface{}
- for key, index := range pkIndexMap {
- curKey := key
- curIndex := index
-
- pkKey = curKey
- pkValues = pkValuesMap[pkKey]
-
- pkIndex = curIndex
- if pkIndex > len(row)-1 {
- continue
- }
- pkValue := row[pkIndex]
- pkValueStr, ok := pkValue.(string)
- if ok && strings.EqualFold(pkValueStr, sqlPlaceholder) {
- currentRowNotPlaceholderNumBeforePkIndex := 0
- for i := range row {
- r := row[i]
- rStr, ok := r.(string)
- if i < pkIndex && ok && !strings.EqualFold(rStr, sqlPlaceholder) {
- currentRowNotPlaceholderNumBeforePkIndex++
- }
- }
- idx := totalPlaceholderNum - currentRowPlaceholderNum + pkIndex - currentRowNotPlaceholderNumBeforePkIndex
- pkValues = append(pkValues, nameValues[idx].Value)
- } else {
- pkValues = append(pkValues, pkValue)
- }
- if _, ok := pkValuesMap[pkKey]; !ok {
- pkValuesMap[pkKey] = pkValues
- }
- }
- }
- } else {
- for _, list := range insertStmt.Lists {
- for pkName, pkIndex := range pkIndexMap {
- tmpPkName := pkName
- tmpPkIndex := pkIndex
- if tmpPkIndex >= len(list) {
- return nil, fmt.Errorf("pkIndex out of range")
- }
- if node, ok := list[tmpPkIndex].(ast.ValueExpr); ok {
- pkValuesMap[tmpPkName] = append(pkValuesMap[tmpPkName], node.GetValue())
- }
- }
- }
- }
-
- return pkValuesMap, nil
- }
-
- // getPkValuesByColumn get pk value by column.
- func (i *insertExecutor) getPkValuesByColumn(ctx context.Context, execCtx *types.ExecContext) (map[string][]interface{}, error) {
- if !i.isAstStmtValid() {
- return nil, nil
- }
-
- tableName, _ := i.parserCtx.GetTableName()
- meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
- if err != nil {
- return nil, err
- }
- pkValuesMap, err := i.parsePkValuesFromStatement(i.parserCtx.InsertStmt, *meta, execCtx.NamedValues)
- if err != nil {
- return nil, err
- }
-
- // generate pkValue by auto increment
- for _, v := range pkValuesMap {
- tmpV := v
- if len(tmpV) == 1 {
- // pk auto generated while single insert primary key is expression
- if _, ok := tmpV[0].(*ast.FuncCallExpr); ok {
- curPkValueMap, err := i.getPkValuesByAuto(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- pkValuesMapMerge(&pkValuesMap, curPkValueMap)
- }
- } else if len(tmpV) > 0 && tmpV[0] == nil {
- // pk auto generated while column exists and value is null
- curPkValueMap, err := i.getPkValuesByAuto(ctx, execCtx)
- if err != nil {
- return nil, err
- }
- pkValuesMapMerge(&pkValuesMap, curPkValueMap)
- }
- }
- return pkValuesMap, nil
- }
-
- func (i *insertExecutor) getPkValuesByAuto(ctx context.Context, execCtx *types.ExecContext) (map[string][]interface{}, error) {
- if !i.isAstStmtValid() {
- return nil, nil
- }
-
- tableName, _ := i.parserCtx.GetTableName()
- metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
- if err != nil {
- return nil, err
- }
-
- pkValuesMap := make(map[string][]interface{})
- pkMetaMap := metaData.GetPrimaryKeyMap()
- if len(pkMetaMap) == 0 {
- return nil, fmt.Errorf("pk map is empty")
- }
- var autoColumnName string
- for _, columnMeta := range pkMetaMap {
- tmpColumnMeta := columnMeta
- if tmpColumnMeta.Autoincrement {
- autoColumnName = tmpColumnMeta.ColumnName
- break
- }
- }
- if len(autoColumnName) == 0 {
- return nil, fmt.Errorf("auto increment column not exist")
- }
-
- updateCount, err := i.businesSQLResult.GetResult().RowsAffected()
- if err != nil {
- return nil, err
- }
-
- lastInsertId, err := i.businesSQLResult.GetResult().LastInsertId()
- if err != nil {
- return nil, err
- }
-
- // If there is batch insert
- // do auto increment base LAST_INSERT_ID and variable `auto_increment_increment`
- if lastInsertId > 0 && updateCount > 1 && canAutoIncrement(pkMetaMap) {
- return i.autoGeneratePks(execCtx, autoColumnName, lastInsertId, updateCount)
- }
-
- if lastInsertId > 0 {
- var pkValues []interface{}
- pkValues = append(pkValues, lastInsertId)
- pkValuesMap[autoColumnName] = pkValues
- return pkValuesMap, nil
- }
-
- return nil, nil
- }
-
- func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool {
- if len(pkMetaMap) != 1 {
- return false
- }
- for _, meta := range pkMetaMap {
- return meta.Autoincrement
- }
- return false
- }
-
- func (i *insertExecutor) isAstStmtValid() bool {
- return i.parserCtx != nil && i.parserCtx.InsertStmt != nil
- }
-
- func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) {
- var step int64
- if i.incrementStep > 0 {
- step = int64(i.incrementStep)
- } else {
- // get step by query sql
- stmt, err := execCtx.Conn.Prepare("SHOW VARIABLES LIKE 'auto_increment_increment'")
- if err != nil {
- log.Errorf("build prepare stmt: %+v", err)
- return nil, err
- }
-
- rows, err := stmt.Query(nil)
- if err != nil {
- log.Errorf("stmt query: %+v", err)
- return nil, err
- }
-
- if len(rows.Columns()) > 0 {
- var curStep []driver.Value
- if err := rows.Next(curStep); err != nil {
- return nil, err
- }
-
- if curStepInt, ok := curStep[0].(int64); ok {
- step = curStepInt
- }
- } else {
- return nil, fmt.Errorf("query is empty")
- }
- }
-
- if step == 0 {
- return nil, fmt.Errorf("get increment step error")
- }
-
- var pkValues []interface{}
- for j := int64(0); j < updateCount; j++ {
- pkValues = append(pkValues, lastInsetId+j*step)
- }
- pkValuesMap := make(map[string][]interface{})
- pkValuesMap[autoColumnName] = pkValues
- return pkValuesMap, nil
- }
-
// pkValuesMapMerge appends the values of every column in src to the values
// already held for that column in *dest.
//
// The values are appended one by one, so the destination stays a flat list of
// primary key values. A nil dest is ignored; a nil destination map is
// allocated before the first value is stored.
func pkValuesMapMerge(dest *map[string][]interface{}, src map[string][]interface{}) {
	if dest == nil || len(src) == 0 {
		return
	}
	if *dest == nil {
		*dest = make(map[string][]interface{}, len(src))
	}
	for column, values := range src {
		(*dest)[column] = append((*dest)[column], values...)
	}
}
-
- // containsColumns judge sql specify column
- func containsColumns(parseCtx *types.ParseContext) bool {
- if parseCtx == nil || parseCtx.InsertStmt == nil || parseCtx.InsertStmt.Columns == nil {
- return false
- }
- return len(parseCtx.InsertStmt.Columns) > 0
- }
-
- func getInsertRows(insertStmt *ast.InsertStmt, pkIndexArray []int) ([][]interface{}, error) {
- if insertStmt == nil {
- return nil, nil
- }
- if len(insertStmt.Lists) == 0 {
- return nil, nil
- }
- var rows [][]interface{}
-
- for _, nodes := range insertStmt.Lists {
- var row []interface{}
- for i, node := range nodes {
- if _, ok := node.(ast.ParamMarkerExpr); ok {
- row = append(row, sqlPlaceholder)
- } else if newNode, ok := node.(ast.ValueExpr); ok {
- row = append(row, newNode.GetValue())
- } else if newNode, ok := node.(*ast.VariableExpr); ok {
- row = append(row, newNode.Name)
- } else if _, ok := node.(*ast.FuncCallExpr); ok {
- row = append(row, ast.FuncCallExpr{})
- } else {
- for _, index := range pkIndexArray {
- if index == i {
- return nil, fmt.Errorf("Unknown SQLExpr:%v", node)
- }
- }
- row = append(row, ast.DefaultExpr{})
- }
- }
- rows = append(rows, row)
- }
- return rows, nil
- }
|