diff --git a/pkgs/cmdtrie/command_trie.go b/pkgs/cmdtrie/command_trie.go index e2d3c05..db6c821 100644 --- a/pkgs/cmdtrie/command_trie.go +++ b/pkgs/cmdtrie/command_trie.go @@ -7,7 +7,7 @@ import ( "github.com/samber/lo" "gitlink.org.cn/cloudream/common/pkgs/trie" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type ExecuteOption struct { @@ -296,7 +296,7 @@ type CommandTrie[TCtx any, TRet any] struct { func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] { return CommandTrie[TCtx, TRet]{ - anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()), + anyTrie: newAnyCommandTrie(reflect2.TypeOf[TCtx](), reflect2.TypeOf[TRet]()), } } @@ -337,7 +337,7 @@ type VoidCommandTrie[TCtx any] struct { func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] { return VoidCommandTrie[TCtx]{ - anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil), + anyTrie: newAnyCommandTrie(reflect2.TypeOf[TCtx](), nil), } } @@ -368,7 +368,7 @@ type StaticCommandTrie[TRet any] struct { func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] { return StaticCommandTrie[TRet]{ - anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()), + anyTrie: newAnyCommandTrie(nil, reflect2.TypeOf[TRet]()), } } diff --git a/pkgs/logger/global_logger.go b/pkgs/logger/global_logger.go index 793f1fc..3a469fc 100644 --- a/pkgs/logger/global_logger.go +++ b/pkgs/logger/global_logger.go @@ -8,7 +8,7 @@ import ( nested "github.com/antonfisher/nested-logrus-formatter" "github.com/sirupsen/logrus" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) // 输出日志到标准输出。适用于没有设计好日志输出方案时临时使用。 @@ -123,6 +123,6 @@ func WithField(key string, val any) Logger { func WithType[T any](key string) Logger { return &logrusLogger{ - entry: logrus.WithField(key, myreflect.TypeOf[T]().Name()), + entry: logrus.WithField(key, reflect2.TypeOf[T]().Name()), } } diff --git a/pkgs/mq/client.go b/pkgs/mq/client.go index 8709347..db68963 100644 --- a/pkgs/mq/client.go +++ b/pkgs/mq/client.go @@ -10,7 +10,7 @@ import ( "github.com/streadway/amqp" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/common/pkgs/logger" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) const ( @@ -323,8 +323,8 @@ func Request[TSvc any, TReq MessageBody, TResp MessageBody](_ func(svc TSvc, msg respBody, ok := resp.Body.(TResp) if !ok { return defRet, fmt.Errorf("expect a %s body, but got %s", - myreflect.ElemTypeOf[TResp]().Name(), - myreflect.TypeOfValue(resp.Body).Name()) + reflect2.ElemTypeOf[TResp]().Name(), + reflect2.TypeOfValue(resp.Body).Name()) } return respBody, nil diff --git a/pkgs/mq/message.go b/pkgs/mq/message.go index 36a7ad3..452716a 100644 --- a/pkgs/mq/message.go +++ b/pkgs/mq/message.go @@ -2,7 +2,7 @@ package mq import ( "gitlink.org.cn/cloudream/common/pkgs/types" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" "gitlink.org.cn/cloudream/common/utils/serder" ) @@ -83,7 +83,7 @@ var msgBodyTypeUnion = serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTy // 所有新定义的Message都需要在init中调用此函数 func RegisterMessage[T MessageBody]() { - err := msgBodyTypeUnion.Add(myreflect.TypeOf[T]()) + err := msgBodyTypeUnion.Add(reflect2.TypeOf[T]()) if err != nil { panic(err) } diff --git a/pkgs/mq/message_dispatcher.go b/pkgs/mq/message_dispatcher.go index 666e872..629fbf7 100644 --- a/pkgs/mq/message_dispatcher.go +++ b/pkgs/mq/message_dispatcher.go @@ -3,27 +3,27 @@ package mq import ( "fmt" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type HandlerFn func(svcBase any, msg *Message) (*Message, error) type MessageDispatcher struct { - Handlers map[myreflect.Type]HandlerFn + Handlers map[reflect2.Type]HandlerFn } func NewMessageDispatcher() MessageDispatcher { return MessageDispatcher{ - Handlers: make(map[myreflect.Type]HandlerFn), + Handlers: make(map[reflect2.Type]HandlerFn), } } -func (h *MessageDispatcher) Add(typ myreflect.Type, handler HandlerFn) { +func (h *MessageDispatcher) Add(typ reflect2.Type, handler HandlerFn) { h.Handlers[typ] = handler } func (h *MessageDispatcher) Handle(svcBase any, msg *Message) (*Message, error) { - typ := myreflect.TypeOfValue(msg.Body) + typ := reflect2.TypeOfValue(msg.Body) fn, ok := h.Handlers[typ] if !ok { return nil, fmt.Errorf("unsupported message type: %s", typ.String()) @@ -34,7 +34,7 @@ func (h *MessageDispatcher) Handle(svcBase any, msg *Message) (*Message, error) // 将Service中的一个接口函数作为指定类型消息的处理函数 func AddServiceFn[TSvc any, TReq MessageBody, TResp MessageBody](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg TReq) (TResp, *CodeMessage)) { - dispatcher.Add(myreflect.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { + dispatcher.Add(reflect2.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { reqMsgBody := reqMsg.Body.(TReq) ret, codeMsg := svcFn(svcBase.(TSvc), reqMsgBody) @@ -47,7 +47,7 @@ func AddServiceFn[TSvc any, TReq MessageBody, TResp MessageBody](dispatcher *Mes // 将Service中的一个*没有返回值的*接口函数作为指定类型消息的处理函数 func AddNoRespServiceFn[TSvc any, TReq MessageBody](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg TReq)) { - dispatcher.Add(myreflect.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { + dispatcher.Add(reflect2.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { reqMsgBody := reqMsg.Body.(TReq) svcFn(svcBase.(TSvc), reqMsgBody) diff --git a/pkgs/mq/message_test.go b/pkgs/mq/message_test.go index b2482fa..03be582 100644 --- a/pkgs/mq/message_test.go +++ b/pkgs/mq/message_test.go @@ -9,7 +9,7 @@ import ( jsoniter "github.com/json-iterator/go" . "github.com/smartystreets/goconvey/convey" "gitlink.org.cn/cloudream/common/pkgs/types" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" "gitlink.org.cn/cloudream/common/utils/serder" ) @@ -42,13 +42,13 @@ func TestMessage(t *testing.T) { Nil: nil, } - jsoniter.RegisterTypeEncoderFunc(myreflect.TypeOf[MyAny]().String(), + jsoniter.RegisterTypeEncoderFunc(reflect2.TypeOf[MyAny]().String(), func(ptr unsafe.Pointer, stream *jsoniter.Stream) { val := *((*MyAny)(ptr)) stream.WriteArrayStart() if val != nil { - stream.WriteString(myreflect.TypeOfValue(val).String()) + stream.WriteString(reflect2.TypeOfValue(val).String()) stream.WriteRaw(",") stream.WriteVal(val) } @@ -58,7 +58,7 @@ func TestMessage(t *testing.T) { return false }) - jsoniter.RegisterTypeDecoderFunc(myreflect.TypeOf[MyAny]().String(), + jsoniter.RegisterTypeDecoderFunc(reflect2.TypeOf[MyAny]().String(), func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { vp := (*MyAny)(ptr) diff --git a/pkgs/typedispatcher/type_dispatcher.go b/pkgs/typedispatcher/type_dispatcher.go index 4270244..57f046a 100644 --- a/pkgs/typedispatcher/type_dispatcher.go +++ b/pkgs/typedispatcher/type_dispatcher.go @@ -1,28 +1,28 @@ package typedispatcher import ( - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type HandlerFn[TRet any] func(val any) TRet type TypeDispatcher[TRet any] struct { - handlers map[myreflect.Type]HandlerFn[TRet] + handlers map[reflect2.Type]HandlerFn[TRet] } func NewTypeDispatcher[TRet any]() TypeDispatcher[TRet] { return TypeDispatcher[TRet]{ - handlers: make(map[myreflect.Type]HandlerFn[TRet]), + handlers: make(map[reflect2.Type]HandlerFn[TRet]), } } -func (t *TypeDispatcher[TRet]) Add(typ myreflect.Type, fn HandlerFn[TRet]) { +func (t *TypeDispatcher[TRet]) Add(typ reflect2.Type, fn HandlerFn[TRet]) { t.handlers[typ] = fn } func (t *TypeDispatcher[TRet]) Dispatch(val any) (TRet, bool) { var ret TRet - typ := myreflect.TypeOfValue(val) + typ := reflect2.TypeOfValue(val) handler, ok := t.handlers[typ] if !ok { return ret, false @@ -32,7 +32,7 @@ func (t *TypeDispatcher[TRet]) Dispatch(val any) (TRet, bool) { } func Add[T any, TRet any](dispatcher TypeDispatcher[TRet], handler func(val T) TRet) { - dispatcher.Add(myreflect.TypeOf[T](), func(val any) TRet { + dispatcher.Add(reflect2.TypeOf[T](), func(val any) TRet { return handler(val.(T)) }) } diff --git a/pkgs/types/union.go b/pkgs/types/union.go index 639e25c..8d0d73c 100644 --- a/pkgs/types/union.go +++ b/pkgs/types/union.go @@ -4,17 +4,17 @@ import ( "fmt" "reflect" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type AnyTypeUnion struct { // 这个集合的类型 - UnionType myreflect.Type + UnionType reflect2.Type // 集合中包含的类型,即遇到UnionType类型的值时,它内部的实际类型的范围 - ElementTypes []myreflect.Type + ElementTypes []reflect2.Type } -func (u *AnyTypeUnion) Include(typ myreflect.Type) bool { +func (u *AnyTypeUnion) Include(typ reflect2.Type) bool { for _, t := range u.ElementTypes { if t == typ { return true @@ -24,7 +24,7 @@ func (u *AnyTypeUnion) Include(typ myreflect.Type) bool { return false } -func (u *AnyTypeUnion) Add(typ myreflect.Type) error { +func (u *AnyTypeUnion) Add(typ reflect2.Type) error { if !typ.AssignableTo(u.UnionType) { return fmt.Errorf("type is not assignable to union type") } @@ -55,7 +55,7 @@ func NewTypeUnion[TU any](eleValues ...TU) TypeUnion[TU] { return TypeUnion[TU]{ AnyTypeUnion{ - UnionType: myreflect.TypeOf[TU](), + UnionType: reflect2.TypeOf[TU](), ElementTypes: eleTypes, }, } diff --git a/sdks/pcm/models.go b/sdks/pcm/models.go index 0b6acb3..a225375 100644 --- a/sdks/pcm/models.go +++ b/sdks/pcm/models.go @@ -34,5 +34,5 @@ const ( TaskStatusPending TaskStatus = "Pending" TaskStatusRunning TaskStatus = "Running" TaskStatusSuccess TaskStatus = "succeeded" - TaskStatuFailed TaskStatus = "failed" + TaskStatusFailed TaskStatus = "failed" ) diff --git a/sdks/scheduler/models.go b/sdks/scheduler/models.go index 2748c55..5ad60b4 100644 --- a/sdks/scheduler/models.go +++ b/sdks/scheduler/models.go @@ -35,7 +35,7 @@ type JobInfo interface { var JobInfoTypeUnion = types.NewTypeUnion[JobInfo]( (*NormalJobInfo)(nil), - (*ResourceJobInfo)(nil), + (*DataReturnJobInfo)(nil), ) var _ = serder.UseTypeUnionInternallyTagged(&JobInfoTypeUnion, "type") @@ -57,8 +57,8 @@ type NormalJobInfo struct { Services JobServicesInfo `json:"services"` } -type ResourceJobInfo struct { - serder.Metadata `union:"Resource"` +type DataReturnJobInfo struct { + serder.Metadata `union:"DataReturn"` JobInfoBase Type string `json:"type"` BucketID cdssdk.BucketID `json:"bucketID"` diff --git a/sdks/scheduler/scheduler_test.go b/sdks/scheduler/scheduler_test.go index 25a2977..9b22a18 100644 --- a/sdks/scheduler/scheduler_test.go +++ b/sdks/scheduler/scheduler_test.go @@ -15,7 +15,7 @@ func Test_JobSet(t *testing.T) { id, err := cli.JobSetSumbit(JobSetSumbitReq{ JobSetInfo: JobSetInfo{ Jobs: []JobInfo{ - &ResourceJobInfo{ + &DataReturnJobInfo{ Type: JobTypeResource, }, &NormalJobInfo{ diff --git a/sdks/storage/package.go b/sdks/storage/package.go index 6655e48..22b6aa2 100644 --- a/sdks/storage/package.go +++ b/sdks/storage/package.go @@ -217,7 +217,7 @@ func (c *PackageService) GetCachedNodes(req PackageGetCachedNodesReq) (*PackageG return nil, err } resp, err := myhttp.GetJSON(url, myhttp.RequestParam{ - Body: req, + Query: req, }) if err != nil { return nil, err @@ -252,7 +252,7 @@ func (c *PackageService) GetLoadedNodes(req PackageGetLoadedNodesReq) (*PackageG return nil, err } resp, err := myhttp.GetJSON(url, myhttp.RequestParam{ - Body: req, + Query: req, }) if err != nil { return nil, err diff --git a/utils/reflect/reflect.go b/utils/reflect2/reflect.go similarity index 96% rename from utils/reflect/reflect.go rename to utils/reflect2/reflect.go index cf98e68..92237c2 100644 --- a/utils/reflect/reflect.go +++ b/utils/reflect2/reflect.go @@ -1,4 +1,4 @@ -package reflect +package reflect2 import "reflect" diff --git a/utils/serder/any_to_any.go b/utils/serder/any_to_any.go index 8fe3889..85ea4fe 100644 --- a/utils/serder/any_to_any.go +++ b/utils/serder/any_to_any.go @@ -4,7 +4,7 @@ import ( "reflect" mp "github.com/mitchellh/mapstructure" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type Converter func(from reflect.Value, to reflect.Value) (interface{}, error) @@ -68,11 +68,11 @@ func AnyToAny(src any, dst any, opts ...AnyToAnyOption) error { // fromAny 如果目的字段实现的FromAny接口,那么通过此接口实现字段类型转换 func fromAny(srcType reflect.Type, targetType reflect.Type, data interface{}) (interface{}, error) { - if myreflect.TypeOfValue(data) == targetType { + if reflect2.TypeOfValue(data) == targetType { return data, nil } - if targetType.Implements(myreflect.TypeOf[FromAny]()) { + if targetType.Implements(reflect2.TypeOf[FromAny]()) { // 非pointer receiver的FromAny没有意义,因为修改不了receiver的内容,所以这里只支持指针类型 if targetType.Kind() == reflect.Pointer { val := reflect.New(targetType.Elem()) @@ -88,7 +88,7 @@ func fromAny(srcType reflect.Type, targetType reflect.Type, data interface{}) (i return val.Interface(), nil } - } else if reflect.PointerTo(targetType).Implements(myreflect.TypeOf[FromAny]()) { + } else if reflect.PointerTo(targetType).Implements(reflect2.TypeOf[FromAny]()) { val := reflect.New(targetType) anyIf := val.Interface().(FromAny) ok, err := anyIf.FromAny(data) @@ -107,12 +107,12 @@ func fromAny(srcType reflect.Type, targetType reflect.Type, data interface{}) (i // 如果源字段实现了ToAny接口,那么通过此接口实现字段类型转换 func toAny(srcType reflect.Type, targetType reflect.Type, data interface{}) (interface{}, error) { - dataType := myreflect.TypeOfValue(data) + dataType := reflect2.TypeOfValue(data) if dataType == targetType { return data, nil } - if dataType.Implements(myreflect.TypeOf[ToAny]()) { + if dataType.Implements(reflect2.TypeOf[ToAny]()) { anyIf := data.(ToAny) dstVal, ok, err := anyIf.ToAny(targetType) if err != nil { @@ -123,7 +123,7 @@ func toAny(srcType reflect.Type, targetType reflect.Type, data interface{}) (int } return dstVal, nil - } else if reflect.PointerTo(dataType).Implements(myreflect.TypeOf[ToAny]()) { + } else if reflect.PointerTo(dataType).Implements(reflect2.TypeOf[ToAny]()) { dataVal := reflect.ValueOf(data) dataPtrVal := reflect.New(dataType) diff --git a/utils/serder/serder_test.go b/utils/serder/serder_test.go index e2c4b8c..4914531 100644 --- a/utils/serder/serder_test.go +++ b/utils/serder/serder_test.go @@ -7,7 +7,7 @@ import ( . "github.com/smartystreets/goconvey/convey" "gitlink.org.cn/cloudream/common/pkgs/types" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) type FromAnyString struct { @@ -28,7 +28,7 @@ type ToAnyString struct { } func (a *ToAnyString) ToAny(typ reflect.Type) (val any, ok bool, err error) { - if typ == myreflect.TypeOf[map[string]any]() { + if typ == reflect2.TypeOf[map[string]any]() { return map[string]any{ "str": "@" + a.Str, }, true, nil @@ -55,7 +55,7 @@ type ToAnySt struct { } func (a *ToAnySt) ToAny(typ reflect.Type) (val any, ok bool, err error) { - if typ == myreflect.TypeOf[FromAnySt]() { + if typ == reflect2.TypeOf[FromAnySt]() { return FromAnySt{ Value: "To:" + a.Value, }, true, nil @@ -69,7 +69,7 @@ type DirToAnySt struct { } func (a DirToAnySt) ToAny(typ reflect.Type) (val any, ok bool, err error) { - if typ == myreflect.TypeOf[FromAnySt]() { + if typ == reflect2.TypeOf[FromAnySt]() { return FromAnySt{ Value: "DirTo:" + a.Value, }, true, nil @@ -181,7 +181,7 @@ func Test_AnyToAny(t *testing.T) { err := AnyToAny(st1, &st2, AnyToAnyOption{ Converters: []Converter{func(from reflect.Value, to reflect.Value) (interface{}, error) { - if from.Type() == myreflect.TypeOf[Struct1]() && to.Type() == myreflect.TypeOf[Struct2]() { + if from.Type() == reflect2.TypeOf[Struct1]() && to.Type() == reflect2.TypeOf[Struct2]() { s1 := from.Interface().(Struct1) return Struct2{ Value: "@" + s1.Value, diff --git a/utils/serder/union_handler.go b/utils/serder/union_handler.go index bef1dd0..ce4ac3e 100644 --- a/utils/serder/union_handler.go +++ b/utils/serder/union_handler.go @@ -9,7 +9,7 @@ import ( "github.com/modern-go/reflect2" "gitlink.org.cn/cloudream/common/pkgs/types" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + ref2 "gitlink.org.cn/cloudream/common/utils/reflect2" ) type anyTypeUnionExternallyTagged struct { @@ -106,14 +106,14 @@ func (u *TypeUnionInternallyTagged[T]) Add(typ reflect.Type) error { } // 要求内嵌Metadata结构体,那么结构体中的字段名就会是Metadata, - field, ok := structType.FieldByName(myreflect.TypeNameOf[Metadata]()) + field, ok := structType.FieldByName(ref2.TypeNameOf[Metadata]()) if !ok { u.TagToType[makeDerefFullTypeName(structType)] = typ return nil } // 为防同名,检查类型是不是也是Metadata - if field.Type != myreflect.TypeOf[Metadata]() { + if field.Type != ref2.TypeOf[Metadata]() { u.TagToType[makeDerefFullTypeName(structType)] = typ return nil } @@ -293,7 +293,7 @@ func (e *ExternallyTaggedEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.St } stream.WriteObjectStart() - valType := myreflect.TypeOfValue(val) + valType := ref2.TypeOfValue(val) if !e.union.Union.Include(valType) { stream.Error = fmt.Errorf("type %v is not in union %v", valType, e.union.Union.UnionType) return diff --git a/utils/serder/walk_test.go b/utils/serder/walk_test.go index 91f0671..de926ed 100644 --- a/utils/serder/walk_test.go +++ b/utils/serder/walk_test.go @@ -6,7 +6,7 @@ import ( "testing" . "github.com/smartystreets/goconvey/convey" - myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/reflect2" ) func Test_WalkValue(t *testing.T) { @@ -38,8 +38,8 @@ func Test_WalkValue(t *testing.T) { isBaseDataType := func(val reflect.Value) bool { typ := val.Type() - return typ == myreflect.TypeOf[int]() || typ == myreflect.TypeOf[bool]() || - typ == myreflect.TypeOf[string]() || typ == myreflect.TypeOf[float32]() || val.IsZero() + return typ == reflect2.TypeOf[int]() || typ == reflect2.TypeOf[bool]() || + typ == reflect2.TypeOf[string]() || typ == reflect2.TypeOf[float32]() || val.IsZero() } toString := func(val any) string { diff --git a/utils/sync2/channel.go b/utils/sync2/channel.go new file mode 100644 index 0000000..f279e09 --- /dev/null +++ b/utils/sync2/channel.go @@ -0,0 +1,77 @@ +package sync2 + +import ( + "context" + "errors" + "sync" +) + +var ErrChannelClosed = errors.New("channel is closed") + +type Channel[T any] struct { + ch chan T + closed chan any + closeOnce sync.Once + err error +} + +func NewChannel[T any]() *Channel[T] { + return &Channel[T]{ + ch: make(chan T), + closed: make(chan any), + } +} + +func (c *Channel[T]) Error() error { + return c.err +} + +func (c *Channel[T]) Send(val T) error { + select { + case c.ch <- val: + return nil + case <-c.closed: + return c.err + } +} + +func (c *Channel[T]) Receive(ctx context.Context) (T, error) { + select { + case val := <-c.ch: + return val, nil + case <-c.closed: + var t T + return t, c.err + case <-ctx.Done(): + var t T + return t, ctx.Err() + } +} + +func (c *Channel[T]) Sender() chan<- T { + return c.ch +} + +func (c *Channel[T]) Receiver() <-chan T { + return c.ch +} + +func (c *Channel[T]) Close() { + c.closeOnce.Do(func() { + close(c.closed) + close(c.ch) + c.err = ErrChannelClosed + }) +} + +func (c *Channel[T]) CloseWithError(err error) { + c.closeOnce.Do(func() { + close(c.closed) + close(c.ch) + c.err = err + }) +} + +func (c *Channel[T]) Closed() <-chan any { + return c.closed +} diff --git a/utils/sync2/select_set.go b/utils/sync2/select_set.go new file mode 100644 index 0000000..592c28b --- /dev/null +++ b/utils/sync2/select_set.go @@ -0,0 +1,48 @@ +package sync2 + +import ( + "reflect" + + "gitlink.org.cn/cloudream/common/utils/lo2" +) + +type SelectCase int + +type SelectSet[T any, C any] struct { + cases []reflect.SelectCase + tags []T +} + +func (s *SelectSet[T, C]) Add(tag T, ch <-chan C) SelectCase { + s.cases = append(s.cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}) + s.tags = append(s.tags, tag) + + return SelectCase(len(s.cases) - 1) +} + +func (s *SelectSet[T, C]) AddDefault(tag T, ch <-chan C) SelectCase { + s.cases = append(s.cases, reflect.SelectCase{Dir: reflect.SelectDefault, Chan: reflect.ValueOf(ch)}) + s.tags = append(s.tags, tag) + + return SelectCase(len(s.cases) - 1) +} + +func (s *SelectSet[T, C]) Remove(caze SelectCase) { + s.cases = lo2.RemoveAt(s.cases, int(caze)) + s.tags = lo2.RemoveAt(s.tags, int(caze)) +} + +func (s *SelectSet[T, C]) Select() (T, C, bool) { + chosen, recv, ok := reflect.Select(s.cases) + if !ok { + var t T + var c C + return t, c, false + } + + return s.tags[chosen], recv.Interface().(C), true +} + +func (s *SelectSet[T, C]) Count() int { + return len(s.cases) +}