|
|
@@ -20,14 +20,15 @@ package exec |
|
|
|
import ( |
|
|
|
"context" |
|
|
|
"fmt" |
|
|
|
|
|
|
|
"github.com/pingcap/tidb/parser/ast" |
|
|
|
"github.com/pingcap/tidb/parser/format" |
|
|
|
"github.com/arana-db/parser/ast" |
|
|
|
"github.com/arana-db/parser/format" |
|
|
|
"github.com/seata/seata-go/pkg/common/bytes" |
|
|
|
"github.com/seata/seata-go/pkg/common/log" |
|
|
|
"github.com/seata/seata-go/pkg/datasource/sql/exec" |
|
|
|
"github.com/seata/seata-go/pkg/datasource/sql/parser" |
|
|
|
"github.com/seata/seata-go/pkg/datasource/sql/types" |
|
|
|
|
|
|
|
_ "github.com/arana-db/parser/test_driver" |
|
|
|
) |
|
|
|
|
|
|
|
type BasicUndoBuilder struct { |
|
|
@@ -81,3 +82,52 @@ func (u *BasicUndoBuilder) buildSelectSQLByUpdate(query string) (string, error) |
|
|
|
|
|
|
|
return sql, nil |
|
|
|
} |
|
|
|
|
|
|
|
// buildSelectSQLByUpdate build select sql from update sql |
|
|
|
func (u *BasicUndoBuilder) buildSelectSQLByInsert(query string) (string, error) { |
|
|
|
p, err := parser.DoParser(query) |
|
|
|
if err != nil { |
|
|
|
return "", err |
|
|
|
} |
|
|
|
|
|
|
|
if p.InsertStmt == nil { |
|
|
|
return "", fmt.Errorf("invalid Insert stmt") |
|
|
|
} |
|
|
|
|
|
|
|
InsertColumns := p.InsertStmt.Columns |
|
|
|
fields := []*ast.SelectField{} |
|
|
|
|
|
|
|
for _, column := range InsertColumns { |
|
|
|
fields = append(fields, &ast.SelectField{ |
|
|
|
Expr: &ast.ColumnNameExpr{ |
|
|
|
Name: column, |
|
|
|
}, |
|
|
|
}) |
|
|
|
} |
|
|
|
insertStmtList := p.InsertStmt.Lists |
|
|
|
var whereStmt ast.ExprNode |
|
|
|
|
|
|
|
whereList := []ast.ExprNode{} |
|
|
|
if len(insertStmtList) > 0 { |
|
|
|
whereList = p.InsertStmt.Lists[0] |
|
|
|
} |
|
|
|
|
|
|
|
if len(whereList) > 0 { |
|
|
|
whereStmt = whereList[0] |
|
|
|
} |
|
|
|
|
|
|
|
selStmt := ast.SelectStmt{ |
|
|
|
SelectStmtOpts: &ast.SelectStmtOpts{}, |
|
|
|
From: p.InsertStmt.Table, |
|
|
|
Where: whereStmt, |
|
|
|
Fields: &ast.FieldList{Fields: fields}, |
|
|
|
TableHints: p.InsertStmt.TableHints, |
|
|
|
} |
|
|
|
|
|
|
|
b := bytes.NewByteBuffer([]byte{}) |
|
|
|
selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) |
|
|
|
sql := string(b.Bytes()) |
|
|
|
log.Infof("build select sql by insert query, sql {}", sql) |
|
|
|
|
|
|
|
return sql, nil |
|
|
|
} |