|
- package distlock
-
- import (
- "context"
- "fmt"
- "sync"
- "time"
-
- "gitlink.org.cn/cloudream/common/pkgs/future"
- "gitlink.org.cn/cloudream/common/pkgs/trie"
- "gitlink.org.cn/cloudream/common/utils/lo2"
- "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/lockprovider"
- "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/types"
- )
-
// AcquireOption holds the per-call settings of Service.Acquire.
type AcquireOption struct {
	// Timeout bounds how long Acquire waits for the requested locks.
	// Acquire treats zero as "no deadline".
	Timeout time.Duration
}

// AcquireOptionFn adjusts an AcquireOption; any number of them may be
// passed to Service.Acquire and they are applied in order.
type AcquireOptionFn func(opt *AcquireOption)

// WithTimeout returns an option that sets AcquireOption.Timeout to timeout,
// overriding Acquire's default of 10 seconds.
func WithTimeout(timeout time.Duration) AcquireOptionFn {
	setTimeout := func(opt *AcquireOption) {
		opt.Timeout = timeout
	}
	return setTimeout
}
-
// Service grants lock requests by routing each lock to the LockProvider
// registered for the lock's path, and keeps a queue of Acquire calls that
// are waiting for their locks.
type Service struct {
	// lock guards the fields below; Acquire and release only touch them
	// while holding it.
	lock *sync.Mutex
	// provdersTrie maps lock paths to the provider responsible for them.
	// NOTE: the name is a misspelling of "providersTrie"; renaming it
	// requires updating every user in the package.
	provdersTrie *trie.Trie[types.LockProvider]
	// acquirings holds pending Acquire calls; tryAcquirings retries them
	// in slice order.
	acquirings []*acquireInfo
	// nextReqID is the numeric source of request IDs; tryAcquireOne
	// increments it each time a request is granted.
	nextReqID int64
}
-
- func NewService() *Service {
- svc := &Service{
- lock: &sync.Mutex{},
- provdersTrie: trie.NewTrie[types.LockProvider](),
- }
-
- svc.provdersTrie.Create([]any{lockprovider.ShardStoreLockPathPrefix, trie.WORD_ANY}).Value = lockprovider.NewShardStoreLock()
- return svc
- }
-
// acquireInfo is one queued Acquire call waiting for its locks.
type acquireInfo struct {
	// Request is the lock request that tryAcquirings retries.
	Request types.LockRequest
	// Callback is resolved with the granted request ID once tryAcquirings
	// manages to apply Request.
	Callback *future.SetValueFuture[types.RequestID]
	// LastErr records the error from the most recent failed attempt to
	// grant Request.
	LastErr error
}
-
- func (svc *Service) Acquire(req types.LockRequest, opts ...AcquireOptionFn) (*Mutex, error) {
- var opt = AcquireOption{
- Timeout: time.Second * 10,
- }
- for _, fn := range opts {
- fn(&opt)
- }
-
- ctx := context.Background()
- if opt.Timeout != 0 {
- var cancel func()
- ctx, cancel = context.WithTimeout(ctx, opt.Timeout)
- defer cancel()
- }
-
- // 就地检测锁是否可用
- svc.lock.Lock()
- defer svc.lock.Unlock()
-
- reqID, err := svc.tryAcquireOne(req)
- if err != nil {
- return nil, err
- }
-
- if reqID != "" {
- return &Mutex{
- svc: svc,
- lockReq: req,
- lockReqID: reqID,
- }, nil
- }
-
- // 就地检测失败,那么就需要异步等待锁可用
- info := &acquireInfo{
- Request: req,
- Callback: future.NewSetValue[types.RequestID](),
- }
- svc.acquirings = append(svc.acquirings, info)
-
- // 等待的时候不加锁
- svc.lock.Unlock()
- reqID, err = info.Callback.Wait(ctx)
- svc.lock.Lock()
-
- if err == nil {
- return &Mutex{
- svc: svc,
- lockReq: req,
- lockReqID: reqID,
- }, nil
- }
-
- if err != future.ErrCanceled {
- lo2.Remove(svc.acquirings, info)
- return nil, err
- }
-
- // 如果第一次等待是超时错误,那么在锁里再尝试获取一次结果
- reqID, err = info.Callback.TryGetValue()
- if err == nil {
- return &Mutex{
- svc: svc,
- lockReq: req,
- lockReqID: reqID,
- }, nil
- }
-
- lo2.Remove(svc.acquirings, info)
- return nil, err
- }
-
- func (s *Service) BeginReentrant() *Reentrant {
- return &Reentrant{
- svc: s,
- }
- }
-
- func (s *Service) release(reqID types.RequestID, req types.LockRequest) {
- s.lock.Lock()
- defer s.lock.Unlock()
-
- s.releaseRequest(reqID, req)
- s.tryAcquirings()
- }
-
- func (a *Service) tryAcquirings() {
- for i := 0; i < len(a.acquirings); i++ {
- req := a.acquirings[i]
-
- reqID, err := a.tryAcquireOne(req.Request)
- if err != nil {
- req.LastErr = err
- continue
- }
-
- req.Callback.SetValue(reqID)
- a.acquirings[i] = nil
- }
-
- a.acquirings = lo2.RemoveAllDefault(a.acquirings)
- }
-
- func (s *Service) tryAcquireOne(req types.LockRequest) (types.RequestID, error) {
- err := s.testOneRequest(req)
- if err != nil {
- return "", err
- }
-
- reqID := types.RequestID(fmt.Sprintf("%d", s.nextReqID))
- s.nextReqID++
-
- s.applyRequest(reqID, req)
- return reqID, nil
- }
-
- func (s *Service) testOneRequest(req types.LockRequest) error {
- for _, lock := range req.Locks {
- n, ok := s.provdersTrie.WalkEnd(lock.Path)
- if !ok || n.Value == nil {
- return fmt.Errorf("lock provider not found for path %v", lock.Path)
- }
-
- err := n.Value.CanLock(lock)
- if err != nil {
- return err
- }
- }
-
- return nil
- }
-
- func (s *Service) applyRequest(reqID types.RequestID, req types.LockRequest) {
- for _, lock := range req.Locks {
- p, _ := s.provdersTrie.WalkEnd(lock.Path)
- p.Value.Lock(reqID, lock)
- }
- }
-
- func (s *Service) releaseRequest(reqID types.RequestID, req types.LockRequest) {
- for _, lock := range req.Locks {
- p, _ := s.provdersTrie.WalkEnd(lock.Path)
- p.Value.Unlock(reqID, lock)
- }
- }
|