You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sql.go 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package executor
  18. import (
  19. "database/sql"
  20. "strings"
  21. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  22. "seata.apache.org/seata-go/pkg/datasource/sql/undo"
  23. )
  24. const (
  25. Dot = "."
  26. EscapeStandard = "\""
  27. EscapeMysql = "`"
  28. )
  29. // DelEscape del escape by db type
  30. func DelEscape(colName string, dbType types.DBType) string {
  31. newColName := delEscape(colName, EscapeStandard)
  32. if dbType == types.DBTypeMySQL {
  33. newColName = delEscape(newColName, EscapeMysql)
  34. }
  35. return newColName
  36. }
  37. // delEscape
  38. func delEscape(colName string, escape string) string {
  39. if colName == "" {
  40. return ""
  41. }
  42. if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape {
  43. // like "scheme"."id" `scheme`.`id`
  44. str := escape + Dot + escape
  45. index := strings.Index(colName, str)
  46. if index > -1 {
  47. return colName[1:index] + Dot + colName[index+len(str):len(colName)-1]
  48. }
  49. return colName[1 : len(colName)-1]
  50. } else {
  51. // like "scheme".id `scheme`.id
  52. str := escape + Dot
  53. index := strings.Index(colName, str)
  54. if index > -1 && string(colName[0]) == escape {
  55. return colName[1:index] + Dot + colName[index+len(str):]
  56. }
  57. // like scheme."id" scheme.`id`
  58. str = Dot + escape
  59. index = strings.Index(colName, str)
  60. if index > -1 && string(colName[len(colName)-1]) == escape {
  61. return colName[0:index] + Dot + colName[index+len(str):len(colName)-1]
  62. }
  63. }
  64. return colName
  65. }
  66. // AddEscape if necessary, add escape by db type
  67. func AddEscape(colName string, dbType types.DBType) string {
  68. if dbType == types.DBTypeMySQL {
  69. return addEscape(colName, dbType, EscapeMysql)
  70. }
  71. return addEscape(colName, dbType, EscapeStandard)
  72. }
  73. func addEscape(colName string, dbType types.DBType, escape string) string {
  74. if colName == "" {
  75. return colName
  76. }
  77. if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape {
  78. return colName
  79. }
  80. if !checkEscape(colName, dbType) {
  81. return colName
  82. }
  83. if strings.Contains(colName, Dot) {
  84. // like "scheme".id `scheme`.id
  85. str := escape + Dot
  86. dotIndex := strings.Index(colName, str)
  87. if dotIndex > -1 {
  88. tempStr := strings.Builder{}
  89. tempStr.WriteString(colName[0 : dotIndex+len(str)])
  90. tempStr.WriteString(escape)
  91. tempStr.WriteString(colName[dotIndex+len(str):])
  92. tempStr.WriteString(escape)
  93. return tempStr.String()
  94. }
  95. // like scheme."id" scheme.`id`
  96. str = Dot + escape
  97. dotIndex = strings.Index(colName, str)
  98. if dotIndex > -1 {
  99. tempStr := strings.Builder{}
  100. tempStr.WriteString(escape)
  101. tempStr.WriteString(colName[0:dotIndex])
  102. tempStr.WriteString(escape)
  103. tempStr.WriteString(colName[dotIndex:])
  104. return tempStr.String()
  105. }
  106. str = Dot
  107. dotIndex = strings.Index(colName, str)
  108. if dotIndex > -1 {
  109. tempStr := strings.Builder{}
  110. tempStr.WriteString(escape)
  111. tempStr.WriteString(colName[0:dotIndex])
  112. tempStr.WriteString(escape)
  113. tempStr.WriteString(Dot)
  114. tempStr.WriteString(escape)
  115. tempStr.WriteString(colName[dotIndex+len(str):])
  116. tempStr.WriteString(escape)
  117. return tempStr.String()
  118. }
  119. }
  120. buf := make([]byte, len(colName)+2)
  121. buf[0], buf[len(buf)-1] = escape[0], escape[0]
  122. for key := range colName {
  123. buf[key+1] = colName[key]
  124. }
  125. return string(buf)
  126. }
  127. // checkEscape check whether given field or table name use keywords. the method has database special logic.
  128. func checkEscape(colName string, dbType types.DBType) bool {
  129. switch dbType {
  130. case types.DBTypeMySQL:
  131. if _, ok := types.GetMysqlKeyWord()[strings.ToUpper(colName)]; ok {
  132. return true
  133. }
  134. return false
  135. // TODO impl Oracle PG SQLServer ...
  136. default:
  137. return true
  138. }
  139. }
  140. // BuildWhereConditionByPKs each pk is a condition.the result will like :" id =? and userCode =?"
  141. func BuildWhereConditionByPKs(pkNameList []string, dbType types.DBType) string {
  142. whereStr := strings.Builder{}
  143. for i := 0; i < len(pkNameList); i++ {
  144. if i > 0 {
  145. whereStr.WriteString(" and ")
  146. }
  147. pkName := pkNameList[i]
  148. whereStr.WriteString(AddEscape(pkName, dbType))
  149. whereStr.WriteString(" = ? ")
  150. }
  151. return whereStr.String()
  152. }
  153. // DataValidationAndGoOn check data valid
  154. // Todo implement dataValidationAndGoOn
  155. func DataValidationAndGoOn(sqlUndoLog undo.SQLUndoLog, conn *sql.Conn) bool {
  156. return true
  157. }
  158. func GetOrderedPkList(image *types.RecordImage, row types.RowImage, dbType types.DBType) ([]types.ColumnImage, error) {
  159. pkColumnNameListByOrder := image.TableMeta.GetPrimaryKeyOnlyName()
  160. pkColumnNameListNoOrder := make([]types.ColumnImage, 0)
  161. pkFields := make([]types.ColumnImage, 0)
  162. for _, column := range row.PrimaryKeys(row.Columns) {
  163. column.ColumnName = DelEscape(column.ColumnName, dbType)
  164. pkColumnNameListNoOrder = append(pkColumnNameListNoOrder, column)
  165. }
  166. for _, pkName := range pkColumnNameListByOrder {
  167. for _, col := range pkColumnNameListNoOrder {
  168. if strings.Index(col.ColumnName, pkName) > -1 {
  169. pkFields = append(pkFields, col)
  170. }
  171. }
  172. }
  173. return pkFields, nil
  174. }