You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

update_executor.go 8.3 kB


  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package at
  18. import (
  19. "context"
  20. "database/sql/driver"
  21. "fmt"
  22. "strings"
  23. "github.com/arana-db/parser/ast"
  24. "github.com/arana-db/parser/format"
  25. "github.com/arana-db/parser/model"
  26. "seata.apache.org/seata-go/pkg/datasource/sql/datasource"
  27. "seata.apache.org/seata-go/pkg/datasource/sql/exec"
  28. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  29. "seata.apache.org/seata-go/pkg/datasource/sql/undo"
  30. "seata.apache.org/seata-go/pkg/datasource/sql/util"
  31. "seata.apache.org/seata-go/pkg/util/bytes"
  32. "seata.apache.org/seata-go/pkg/util/log"
  33. )
  // maxInSize is handed to buildWhereConditionByPKs as its final argument when
  // the after-image query is assembled.
  // NOTE(review): presumably the upper bound on values per IN (...) group;
  // confirm against buildWhereConditionByPKs.
  var maxInSize = 1000
  37. // updateExecutor execute update SQL
  38. type updateExecutor struct {
  39. baseExecutor
  40. parserCtx *types.ParseContext
  41. execContext *types.ExecContext
  42. }
  43. // NewUpdateExecutor get update executor
  44. func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
  45. return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}
  46. }
  47. // ExecContext exec SQL, and generate before image and after image
  48. func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
  49. u.beforeHooks(ctx, u.execContext)
  50. defer func() {
  51. u.afterHooks(ctx, u.execContext)
  52. }()
  53. beforeImage, err := u.beforeImage(ctx)
  54. if err != nil {
  55. return nil, err
  56. }
  57. res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues)
  58. if err != nil {
  59. return nil, err
  60. }
  61. afterImage, err := u.afterImage(ctx, *beforeImage)
  62. if err != nil {
  63. return nil, err
  64. }
  65. if len(beforeImage.Rows) != len(afterImage.Rows) {
  66. return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.")
  67. }
  68. u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
  69. u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage)
  70. return res, nil
  71. }
  72. // beforeImage build before image
  73. func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
  74. if !u.isAstStmtValid() {
  75. return nil, nil
  76. }
  77. selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues)
  78. if err != nil {
  79. return nil, err
  80. }
  81. tableName, _ := u.parserCtx.GetTableName()
  82. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  83. if err != nil {
  84. return nil, err
  85. }
  86. var rowsi driver.Rows
  87. queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
  88. var queryer driver.Queryer
  89. if !ok {
  90. queryer, ok = u.execContext.Conn.(driver.Queryer)
  91. }
  92. if ok {
  93. rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
  94. defer func() {
  95. if rowsi != nil {
  96. rowsi.Close()
  97. }
  98. }()
  99. if err != nil {
  100. log.Errorf("ctx driver query: %+v", err)
  101. return nil, err
  102. }
  103. } else {
  104. log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
  105. return nil, fmt.Errorf("invalid conn")
  106. }
  107. image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  108. if err != nil {
  109. return nil, err
  110. }
  111. lockKey := u.buildLockKey(image, *metaData)
  112. u.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
  113. image.SQLType = u.parserCtx.SQLType
  114. return image, nil
  115. }
  116. // afterImage build after image
  117. func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) {
  118. if !u.isAstStmtValid() {
  119. return nil, nil
  120. }
  121. if len(beforeImage.Rows) == 0 {
  122. return &types.RecordImage{}, nil
  123. }
  124. tableName, _ := u.parserCtx.GetTableName()
  125. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  126. if err != nil {
  127. return nil, err
  128. }
  129. selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)
  130. var rowsi driver.Rows
  131. queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
  132. var queryer driver.Queryer
  133. if !ok {
  134. queryer, ok = u.execContext.Conn.(driver.Queryer)
  135. }
  136. if ok {
  137. rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
  138. defer func() {
  139. if rowsi != nil {
  140. rowsi.Close()
  141. }
  142. }()
  143. if err != nil {
  144. log.Errorf("ctx driver query: %+v", err)
  145. return nil, err
  146. }
  147. } else {
  148. log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
  149. return nil, fmt.Errorf("invalid conn")
  150. }
  151. afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  152. if err != nil {
  153. return nil, err
  154. }
  155. afterImage.SQLType = u.parserCtx.SQLType
  156. return afterImage, nil
  157. }
  158. func (u *updateExecutor) isAstStmtValid() bool {
  159. return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil
  160. }
  161. // buildAfterImageSQL build the SQL to query after image data
  162. func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) {
  163. if len(beforeImage.Rows) == 0 {
  164. return "", nil
  165. }
  166. sb := strings.Builder{}
  167. // todo: OnlyCareUpdateColumns should load from config first
  168. var selectFields string
  169. var separator = ","
  170. if undo.UndoConfig.OnlyCareUpdateColumns {
  171. for _, row := range beforeImage.Rows {
  172. for _, column := range row.Columns {
  173. selectFields += column.ColumnName + separator
  174. }
  175. }
  176. selectFields = strings.TrimSuffix(selectFields, separator)
  177. } else {
  178. selectFields = "*"
  179. }
  180. sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ")
  181. whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize)
  182. sb.WriteString(" " + whereSQL + " ")
  183. return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName())
  184. }
  185. // buildAfterImageSQL build the SQL to query before image data
  186. func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
  187. if !u.isAstStmtValid() {
  188. log.Errorf("invalid update stmt")
  189. return "", nil, fmt.Errorf("invalid update stmt")
  190. }
  191. updateStmt := u.parserCtx.UpdateStmt
  192. fields := make([]*ast.SelectField, 0, len(updateStmt.List))
  193. if undo.UndoConfig.OnlyCareUpdateColumns {
  194. for _, column := range updateStmt.List {
  195. fields = append(fields, &ast.SelectField{
  196. Expr: &ast.ColumnNameExpr{
  197. Name: column.Column,
  198. },
  199. })
  200. }
  201. // select indexes columns
  202. tableName, _ := u.parserCtx.GetTableName()
  203. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  204. if err != nil {
  205. return "", nil, err
  206. }
  207. for _, columnName := range metaData.GetPrimaryKeyOnlyName() {
  208. fields = append(fields, &ast.SelectField{
  209. Expr: &ast.ColumnNameExpr{
  210. Name: &ast.ColumnName{
  211. Name: model.CIStr{
  212. O: columnName,
  213. L: columnName,
  214. },
  215. },
  216. },
  217. })
  218. }
  219. } else {
  220. fields = append(fields, &ast.SelectField{
  221. Expr: &ast.ColumnNameExpr{
  222. Name: &ast.ColumnName{
  223. Name: model.CIStr{
  224. O: "*",
  225. L: "*",
  226. },
  227. },
  228. },
  229. })
  230. }
  231. selStmt := ast.SelectStmt{
  232. SelectStmtOpts: &ast.SelectStmtOpts{},
  233. From: updateStmt.TableRefs,
  234. Where: updateStmt.Where,
  235. Fields: &ast.FieldList{Fields: fields},
  236. OrderBy: updateStmt.Order,
  237. Limit: updateStmt.Limit,
  238. TableHints: updateStmt.TableHints,
  239. LockInfo: &ast.SelectLockInfo{
  240. LockType: ast.SelectLockForUpdate,
  241. },
  242. }
  243. b := bytes.NewByteBuffer([]byte{})
  244. _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
  245. sql := string(b.Bytes())
  246. log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
  247. return sql, u.buildSelectArgs(&selStmt, args), nil
  248. }