You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

driver.go 5.9 kB


  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package sql
  18. import (
  19. "context"
  20. "database/sql"
  21. "database/sql/driver"
  22. "fmt"
  23. "reflect"
  24. "strings"
  25. "unsafe"
  26. "github.com/seata/seata-go/pkg/common/log"
  27. "github.com/go-sql-driver/mysql"
  28. "github.com/seata/seata-go/pkg/datasource/sql/datasource"
  29. "github.com/seata/seata-go/pkg/datasource/sql/types"
  30. "github.com/seata/seata-go/pkg/protocol/branch"
  31. )
  32. const (
  33. // SeataATMySQLDriver MySQL driver for AT mode
  34. SeataATMySQLDriver = "seata-at-mysql"
  35. // SeataXAMySQLDriver MySQL driver for XA mode
  36. SeataXAMySQLDriver = "seata-xa-mysql"
  37. )
  38. func init() {
  39. sql.Register(SeataATMySQLDriver, &seataATDriver{
  40. seataDriver: &seataDriver{
  41. transType: types.ATMode,
  42. target: mysql.MySQLDriver{},
  43. },
  44. })
  45. sql.Register(SeataXAMySQLDriver, &seataXADriver{
  46. seataDriver: &seataDriver{
  47. transType: types.XAMode,
  48. target: mysql.MySQLDriver{},
  49. },
  50. })
  51. }
  52. type seataATDriver struct {
  53. *seataDriver
  54. }
  55. func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) {
  56. connector, err := d.seataDriver.OpenConnector(name)
  57. if err != nil {
  58. return nil, err
  59. }
  60. _connector, _ := connector.(*seataConnector)
  61. _connector.transType = types.ATMode
  62. return &seataATConnector{
  63. seataConnector: _connector,
  64. }, nil
  65. }
  66. type seataXADriver struct {
  67. *seataDriver
  68. }
  69. func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) {
  70. connector, err := d.seataDriver.OpenConnector(name)
  71. if err != nil {
  72. return nil, err
  73. }
  74. _connector, _ := connector.(*seataConnector)
  75. _connector.transType = types.XAMode
  76. return &seataXAConnector{
  77. seataConnector: _connector,
  78. }, nil
  79. }
  80. type seataDriver struct {
  81. transType types.TransactionType
  82. target driver.Driver
  83. }
  84. func (d *seataDriver) Open(name string) (driver.Conn, error) {
  85. conn, err := d.target.Open(name)
  86. if err != nil {
  87. log.Errorf("open target connection: %w", err)
  88. return nil, err
  89. }
  90. v := reflect.ValueOf(conn)
  91. if v.Kind() == reflect.Ptr {
  92. v = v.Elem()
  93. }
  94. field := v.FieldByName("connector")
  95. proxy, err := d.OpenConnector(name)
  96. if err != nil {
  97. log.Errorf("open connector: %w", err)
  98. return nil, err
  99. }
  100. SetUnexportedField(field, proxy)
  101. return conn, nil
  102. }
  103. func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) {
  104. c = &dsnConnector{dsn: name, driver: d.target}
  105. if driverCtx, ok := d.target.(driver.DriverContext); ok {
  106. c, err = driverCtx.OpenConnector(name)
  107. if err != nil {
  108. log.Errorf("open connector: %w", err)
  109. return nil, err
  110. }
  111. }
  112. dbType := types.ParseDBType(d.getTargetDriverName())
  113. if dbType == types.DBTypeUnknown {
  114. return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
  115. }
  116. proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name)
  117. if err != nil {
  118. log.Errorf("register resource: %w", err)
  119. return nil, err
  120. }
  121. return proxy, nil
  122. }
  123. func (d *seataDriver) getTargetDriverName() string {
  124. return "mysql"
  125. }
  126. type dsnConnector struct {
  127. dsn string
  128. driver driver.Driver
  129. }
  130. func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
  131. return t.driver.Open(t.dsn)
  132. }
  133. func (t *dsnConnector) Driver() driver.Driver {
  134. return t.driver
  135. }
  136. func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB,
  137. dataSourceName string, opts ...seataOption) (driver.Connector, error) {
  138. conf := loadConfig()
  139. for i := range opts {
  140. opts[i](conf)
  141. }
  142. if err := conf.validate(); err != nil {
  143. log.Errorf("invalid conf: %w", err)
  144. return connector, err
  145. }
  146. options := []dbOption{
  147. withGroupID(conf.GroupID),
  148. withResourceID(parseResourceID(dataSourceName)),
  149. withConf(conf),
  150. withTarget(db),
  151. withDBType(dbType),
  152. }
  153. res, err := newResource(options...)
  154. if err != nil {
  155. log.Errorf("create new resource: %w", err)
  156. return nil, err
  157. }
  158. if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil {
  159. log.Errorf("regisiter resource: %w", err)
  160. return nil, err
  161. }
  162. return &seataConnector{
  163. res: res,
  164. target: connector,
  165. conf: conf,
  166. }, nil
  167. }
  168. type (
  169. seataOption func(cfg *seataServerConfig)
  170. // seataServerConfig
  171. seataServerConfig struct {
  172. // GroupID
  173. GroupID string `yaml:"groupID"`
  174. // BranchType
  175. BranchType branch.BranchType
  176. // Endpoints
  177. Endpoints []string `yaml:"endpoints" json:"endpoints"`
  178. }
  179. )
  180. func (c *seataServerConfig) validate() error {
  181. return nil
  182. }
  183. // loadConfig
  184. // TODO wait finish
  185. func loadConfig() *seataServerConfig {
  186. // 先设置默认配置
  187. // 从默认文件获取
  188. return &seataServerConfig{
  189. GroupID: "DEFAULT_GROUP",
  190. BranchType: branch.BranchTypeAT,
  191. Endpoints: []string{"127.0.0.1:8888"},
  192. }
  193. }
  194. func parseResourceID(dsn string) string {
  195. i := strings.Index(dsn, "?")
  196. res := dsn
  197. if i > 0 {
  198. res = dsn[:i]
  199. }
  200. return strings.ReplaceAll(res, ",", "|")
  201. }
  202. func GetUnexportedField(field reflect.Value) interface{} {
  203. return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
  204. }
  205. func SetUnexportedField(field reflect.Value, value interface{}) {
  206. reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).
  207. Elem().
  208. Set(reflect.ValueOf(value))
  209. }