// Package distlock provides a lock Service that routes lock requests to lock
// providers by path and queues requests that cannot be granted immediately.
package distlock

import (
	"context"
	"fmt"
	"sync"
	"time"

	"gitlink.org.cn/cloudream/common/pkgs/future"
	"gitlink.org.cn/cloudream/common/pkgs/trie"
	"gitlink.org.cn/cloudream/common/utils/lo2"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/lockprovider"
	"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/types"
)

// AcquireOption holds the tunable settings of a single Acquire call.
type AcquireOption struct {
	// Timeout bounds how long Acquire waits for a contended lock.
	// A zero value means no deadline is applied.
	Timeout time.Duration
}

// AcquireOptionFn mutates an AcquireOption; pass values of this type to Acquire.
type AcquireOptionFn func(opt *AcquireOption)

// WithTimeout returns an option that sets AcquireOption.Timeout to timeout.
func WithTimeout(timeout time.Duration) AcquireOptionFn {
	return func(opt *AcquireOption) {
		opt.Timeout = timeout
	}
}

// Service grants and releases lock requests. All mutable state (the provider
// trie contents, the waiter queue and the request ID counter) is guarded by lock.
type Service struct {
	lock         *sync.Mutex
	provdersTrie *trie.Trie[types.LockProvider]
	acquirings   []*acquireInfo // requests waiting for a conflicting lock to be released
	nextReqID    int64          // source of the next request ID
}

// NewService creates a Service with the shard-store lock provider registered
// under the path {lockprovider.ShardStoreLockPathPrefix, trie.WORD_ANY}.
func NewService() *Service {
	svc := &Service{
		lock:         &sync.Mutex{},
		provdersTrie: trie.NewTrie[types.LockProvider](),
	}

	svc.provdersTrie.Create([]any{lockprovider.ShardStoreLockPathPrefix, trie.WORD_ANY}).Value = lockprovider.NewShardStoreLock()
	return svc
}

// acquireInfo is one queued lock request.
type acquireInfo struct {
	Request  types.LockRequest
	Callback *future.SetValueFuture[types.RequestID] // completed with the request ID once the lock is granted
	LastErr  error                                   // most recent reason the request could not be granted
}

// Acquire obtains every lock in req and returns a Mutex representing them.
//
// If the locks are available they are taken immediately. If a provider
// currently refuses the request, the request is queued and Acquire waits until
// a release makes it grantable or the timeout elapses (10 seconds by default,
// see WithTimeout; a zero timeout waits without a deadline). A request whose
// path has no registered provider fails immediately, because waiting cannot
// resolve that.
//
// When waiting fails, the returned error wraps the most recent conflict
// reported by the provider.
func (svc *Service) Acquire(req types.LockRequest, opts ...AcquireOptionFn) (*Mutex, error) {
	opt := AcquireOption{
		Timeout: time.Second * 10,
	}
	for _, fn := range opts {
		fn(&opt)
	}

	ctx := context.Background()
	if opt.Timeout != 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, opt.Timeout)
		defer cancel()
	}

	grant := func(reqID types.RequestID) *Mutex {
		return &Mutex{
			svc:       svc,
			lockReq:   req,
			lockReqID: reqID,
		}
	}

	// Check in place whether the lock is available.
	svc.lock.Lock()
	defer svc.lock.Unlock()

	reqID, err := svc.tryAcquireOne(req)
	if err == nil {
		return grant(reqID), nil
	}

	// A missing provider is a permanent failure; report the original error
	// right away instead of waiting for something that cannot happen.
	if svc.checkProviders(req) != nil {
		return nil, err
	}

	// The in-place attempt failed because of a conflict, so queue the request
	// and wait for a release to grant it.
	info := &acquireInfo{
		Request:  req,
		Callback: future.NewSetValue[types.RequestID](),
		LastErr:  err,
	}
	svc.acquirings = append(svc.acquirings, info)

	// Do not hold the service lock while waiting, otherwise release could
	// never run. The deferred Unlock pairs with the Lock taken after the wait.
	svc.lock.Unlock()
	reqID, waitErr := info.Callback.Wait(ctx)
	svc.lock.Lock()

	if waitErr == nil {
		return grant(reqID), nil
	}

	// The request may have been granted between the wait giving up and the
	// lock being re-taken. The future is only completed while svc.lock is
	// held, so checking it here is race free and prevents leaking a lock that
	// was already applied on our behalf.
	if reqID, tryErr := info.Callback.TryGetValue(); tryErr == nil {
		return grant(reqID), nil
	}

	// Still not granted: drop the waiter so that a later release does not hand
	// the lock to a caller that has already gone away. The result must be
	// assigned back, since the callee cannot shrink our slice.
	svc.acquirings = lo2.Remove(svc.acquirings, info)
	return nil, fmt.Errorf("wait for lock: %v: %w", waitErr, info.LastErr)
}

// BeginReentrant returns a Reentrant bound to this service.
func (svc *Service) BeginReentrant() *Reentrant {
	return &Reentrant{
		svc: svc,
	}
}

// release gives back the locks held under reqID and then retries every queued
// request, since the release may have unblocked some of them.
func (svc *Service) release(reqID types.RequestID, req types.LockRequest) {
	svc.lock.Lock()
	defer svc.lock.Unlock()

	svc.releaseRequest(reqID, req)
	svc.tryAcquirings()
}

// tryAcquirings attempts to grant each queued request in queue order. Granted
// requests have their future completed and are removed from the queue; the
// others record the failure in LastErr and stay queued.
// The caller must hold svc.lock.
func (svc *Service) tryAcquirings() {
	for i, info := range svc.acquirings {
		reqID, err := svc.tryAcquireOne(info.Request)
		if err != nil {
			info.LastErr = err
			continue
		}

		info.Callback.SetValue(reqID)
		// Mark the slot for removal; compaction happens once after the loop.
		svc.acquirings[i] = nil
	}

	svc.acquirings = lo2.RemoveAllDefault(svc.acquirings)
}

// tryAcquireOne grants req if every lock in it can be taken right now,
// returning the newly allocated request ID. On failure nothing is locked and
// the returned ID is empty. The caller must hold svc.lock.
func (svc *Service) tryAcquireOne(req types.LockRequest) (types.RequestID, error) {
	if err := svc.testOneRequest(req); err != nil {
		return "", err
	}

	reqID := types.RequestID(fmt.Sprintf("%d", svc.nextReqID))
	svc.nextReqID++

	svc.applyRequest(reqID, req)
	return reqID, nil
}

// checkProviders reports an error if any lock in req has no registered provider.
func (svc *Service) checkProviders(req types.LockRequest) error {
	for _, lock := range req.Locks {
		n, ok := svc.provdersTrie.WalkEnd(lock.Path)
		if !ok || n.Value == nil {
			return fmt.Errorf("lock provider not found for path %v", lock.Path)
		}
	}

	return nil
}

// testOneRequest reports whether every lock in req could be taken right now.
// It returns the first error found, either a missing provider or the
// provider's CanLock refusal, and does not modify any lock state itself.
func (svc *Service) testOneRequest(req types.LockRequest) error {
	for _, lock := range req.Locks {
		n, ok := svc.provdersTrie.WalkEnd(lock.Path)
		if !ok || n.Value == nil {
			return fmt.Errorf("lock provider not found for path %v", lock.Path)
		}

		if err := n.Value.CanLock(lock); err != nil {
			return err
		}
	}

	return nil
}

// applyRequest records every lock in req as held by reqID.
// The lookup result is not checked: callers only reach this after
// testOneRequest has confirmed that a provider exists for every path.
func (svc *Service) applyRequest(reqID types.RequestID, req types.LockRequest) {
	for _, lock := range req.Locks {
		p, _ := svc.provdersTrie.WalkEnd(lock.Path)
		p.Value.Lock(reqID, lock)
	}
}

// releaseRequest releases every lock in req held by reqID.
// The lookup result is not checked.
// NOTE(review): this presumably relies on req having been granted earlier, so
// that a provider exists for every path — confirm against callers of release.
func (svc *Service) releaseRequest(reqID types.RequestID, req types.LockRequest) {
	for _, lock := range req.Locks {
		p, _ := svc.provdersTrie.WalkEnd(lock.Path)
		p.Value.Unlock(reqID, lock)
	}
}