You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inmemory.go 7.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package inmemory
  2. import (
  3. "fmt"
  4. "sync"
  5. "github.com/opentrx/seata-golang/v2/pkg/apis"
  6. "github.com/opentrx/seata-golang/v2/pkg/tc/model"
  7. "github.com/opentrx/seata-golang/v2/pkg/tc/storage"
  8. "github.com/opentrx/seata-golang/v2/pkg/tc/storage/driver/factory"
  9. "github.com/opentrx/seata-golang/v2/pkg/util/log"
  10. )
  11. func init() {
  12. factory.Register("inmemory", &inMemoryFactory{})
  13. }
  14. // inMemoryFactory implements the factory.StorageDriverFactory interface
  15. type inMemoryFactory struct{}
  16. func (factory *inMemoryFactory) Create(parameters map[string]interface{}) (storage.Driver, error) {
  17. return &driver{
  18. SessionMap: &sync.Map{},
  19. LockMap: &sync.Map{},
  20. }, nil
  21. }
  // driver is the in-memory storage driver; it is intended only for testing.
  type driver struct {
  	// SessionMap holds *model.GlobalTransaction values keyed by the global
  	// session XID.
  	SessionMap *sync.Map
  	// LockMap holds the owning transaction ID keyed by row key.
  	LockMap *sync.Map
  }
  27. // Add global session.
  28. func (driver *driver) AddGlobalSession(session *apis.GlobalSession) error {
  29. driver.SessionMap.Store(session.XID, &model.GlobalTransaction{
  30. GlobalSession: session,
  31. BranchSessions: make(map[*apis.BranchSession]bool),
  32. })
  33. return nil
  34. }
  35. // Find global session.
  36. func (driver *driver) FindGlobalSession(xid string) *apis.GlobalSession {
  37. globalTransaction, ok := driver.SessionMap.Load(xid)
  38. if ok {
  39. gt := globalTransaction.(*model.GlobalTransaction)
  40. return gt.GlobalSession
  41. }
  42. return nil
  43. }
  44. // Find global sessions list.
  45. func (driver *driver) FindGlobalSessions(statuses []apis.GlobalSession_GlobalStatus) []*apis.GlobalSession {
  46. contains := func(statuses []apis.GlobalSession_GlobalStatus, status apis.GlobalSession_GlobalStatus) bool {
  47. for _, s := range statuses {
  48. if s == status {
  49. return true
  50. }
  51. }
  52. return false
  53. }
  54. var sessions = make([]*apis.GlobalSession, 0)
  55. driver.SessionMap.Range(func(key, value interface{}) bool {
  56. session := value.(*model.GlobalTransaction)
  57. if contains(statuses, session.Status) {
  58. sessions = append(sessions, session.GlobalSession)
  59. }
  60. return true
  61. })
  62. return sessions
  63. }
  64. // Find global sessions list with addressing identities
  65. func (driver *driver) FindGlobalSessionsWithAddressingIdentities(statuses []apis.GlobalSession_GlobalStatus,
  66. addressingIdentities []string) []*apis.GlobalSession {
  67. contain := func(statuses []apis.GlobalSession_GlobalStatus, status apis.GlobalSession_GlobalStatus) bool {
  68. for _, s := range statuses {
  69. if s == status {
  70. return true
  71. }
  72. }
  73. return false
  74. }
  75. containAddressing := func(addressingIdentities []string, addressing string) bool {
  76. for _, s := range addressingIdentities {
  77. if s == addressing {
  78. return true
  79. }
  80. }
  81. return false
  82. }
  83. var sessions = make([]*apis.GlobalSession, 0)
  84. driver.SessionMap.Range(func(key, value interface{}) bool {
  85. session := value.(*model.GlobalTransaction)
  86. if contain(statuses, session.Status) && containAddressing(addressingIdentities, session.Addressing) {
  87. sessions = append(sessions, session.GlobalSession)
  88. }
  89. return true
  90. })
  91. return sessions
  92. }
  93. // All sessions collection.
  94. func (driver *driver) AllSessions() []*apis.GlobalSession {
  95. var sessions = make([]*apis.GlobalSession, 0)
  96. driver.SessionMap.Range(func(key, value interface{}) bool {
  97. session := value.(*model.GlobalTransaction)
  98. sessions = append(sessions, session.GlobalSession)
  99. return true
  100. })
  101. return sessions
  102. }
  103. // Update global session status.
  104. func (driver *driver) UpdateGlobalSessionStatus(session *apis.GlobalSession, status apis.GlobalSession_GlobalStatus) error {
  105. globalTransaction, ok := driver.SessionMap.Load(session.XID)
  106. if ok {
  107. gt := globalTransaction.(*model.GlobalTransaction)
  108. gt.Status = status
  109. return nil
  110. }
  111. return fmt.Errorf("could not found global transaction xid = %s", session.XID)
  112. }
  113. // Inactive global session.
  114. func (driver *driver) InactiveGlobalSession(session *apis.GlobalSession) error {
  115. globalTransaction, ok := driver.SessionMap.Load(session.XID)
  116. if ok {
  117. gt := globalTransaction.(*model.GlobalTransaction)
  118. gt.Active = false
  119. return nil
  120. }
  121. return fmt.Errorf("could not found global transaction xid = %s", session.XID)
  122. }
  123. // Remove global session.
  124. func (driver *driver) RemoveGlobalSession(session *apis.GlobalSession) error {
  125. driver.SessionMap.Delete(session.XID)
  126. return nil
  127. }
  128. // Add branch session.
  129. func (driver *driver) AddBranchSession(globalSession *apis.GlobalSession, session *apis.BranchSession) error {
  130. globalTransaction, ok := driver.SessionMap.Load(globalSession.XID)
  131. if ok {
  132. gt := globalTransaction.(*model.GlobalTransaction)
  133. gt.BranchSessions[session] = true
  134. return nil
  135. }
  136. return fmt.Errorf("could not found global transaction xid = %s", session.XID)
  137. }
  138. // Find branch session.
  139. func (driver *driver) FindBranchSessions(xid string) []*apis.BranchSession {
  140. globalTransaction, ok := driver.SessionMap.Load(xid)
  141. if ok {
  142. gt := globalTransaction.(*model.GlobalTransaction)
  143. branchSessions := make([]*apis.BranchSession, 0)
  144. for bs := range gt.BranchSessions {
  145. branchSessions = append(branchSessions, bs)
  146. }
  147. return branchSessions
  148. }
  149. return nil
  150. }
  151. // Find branch session.
  152. func (driver *driver) FindBatchBranchSessions(xids []string) []*apis.BranchSession {
  153. branchSessions := make([]*apis.BranchSession, 0)
  154. for i := 0; i < len(xids); i++ {
  155. globalTransaction, ok := driver.SessionMap.Load(xids[i])
  156. if ok {
  157. gt := globalTransaction.(*model.GlobalTransaction)
  158. for bs := range gt.BranchSessions {
  159. branchSessions = append(branchSessions, bs)
  160. }
  161. }
  162. }
  163. return branchSessions
  164. }
  165. // Update branch session status.
  166. func (driver *driver) UpdateBranchSessionStatus(session *apis.BranchSession, status apis.BranchSession_BranchStatus) error {
  167. session.Status = status
  168. return nil
  169. }
  170. // Remove branch session.
  171. func (driver *driver) RemoveBranchSession(globalSession *apis.GlobalSession, session *apis.BranchSession) error {
  172. globalTransaction, ok := driver.SessionMap.Load(globalSession.XID)
  173. if ok {
  174. gt := globalTransaction.(*model.GlobalTransaction)
  175. delete(gt.BranchSessions, session)
  176. return nil
  177. }
  178. return fmt.Errorf("could not found global transaction xid = %s", session.XID)
  179. }
  180. // AcquireLock Acquire lock boolean.
  181. func (driver *driver) AcquireLock(rowLocks []*apis.RowLock) bool {
  182. if rowLocks == nil {
  183. return true
  184. }
  185. for _, rowLock := range rowLocks {
  186. previousLockTransactionID, loaded := driver.LockMap.LoadOrStore(rowLock.RowKey, rowLock.TransactionID)
  187. if loaded {
  188. if previousLockTransactionID == rowLock.TransactionID {
  189. // Locked by me before
  190. continue
  191. } else {
  192. log.Infof("Global rowLock on [%s:%s] is holding by %d", rowLock.TableName, rowLock.PK, previousLockTransactionID)
  193. driver.ReleaseLock(rowLocks)
  194. return false
  195. }
  196. }
  197. }
  198. return true
  199. }
  200. // ReleaseLock Unlock boolean.
  201. func (driver *driver) ReleaseLock(rowLocks []*apis.RowLock) bool {
  202. if rowLocks == nil {
  203. return true
  204. }
  205. for _, rowLock := range rowLocks {
  206. lockedTransactionID, loaded := driver.LockMap.Load(rowLock.RowKey)
  207. if loaded && lockedTransactionID == rowLock.TransactionID {
  208. driver.LockMap.Delete(rowLock.RowKey)
  209. }
  210. }
  211. return true
  212. }
  213. // IsLockable Is lockable boolean.
  214. func (driver *driver) IsLockable(xid string, resourceID string, lockKey string) bool {
  215. rowLocks := storage.CollectRowLocks(lockKey, resourceID, xid)
  216. for _, rowLock := range rowLocks {
  217. lockedTransactionID, loaded := driver.LockMap.Load(rowLock.RowKey)
  218. if loaded && lockedTransactionID != rowLock.TransactionID {
  219. return false
  220. }
  221. }
  222. return true
  223. }