package rpc import ( "fmt" "io" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/common/utils/io2" "gitlink.org.cn/cloudream/common/utils/serder" "gitlink.org.cn/cloudream/jcs-pub/common/ecode" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func UnaryClient[Resp, Req any](apiFn func(context.Context, *Request, ...grpc.CallOption) (*Response, error), ctx context.Context, req Req) (Resp, *CodeError) { data, err := serder.ObjectToJSONEx(req) if err != nil { var resp Resp return resp, Failed(errorcode.OperationFailed, err.Error()) } resp, err := apiFn(ctx, &Request{ Payload: data, }) if err != nil { var resp Resp return resp, ExtractCodeError(err) } ret, err := serder.JSONToObjectEx[Resp](resp.Payload) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } return ret, nil } func UnaryServer[Resp, Req any](apiFn func(context.Context, Req) (Resp, *CodeError), ctx context.Context, req *Request) (*Response, error) { rreq, err := serder.JSONToObjectEx[Req](req.Payload) if err != nil { return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) } ret, cerr := apiFn(ctx, rreq) if cerr != nil { return nil, WrapCodeError(cerr) } data, err := serder.ObjectToJSONEx(ret) if err != nil { return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) } return &Response{ Payload: data, }, nil } type UploadStreamAPIClient interface { GRPCChunkedWriter CloseAndRecv() (*Response, error) } type UploadStreamAPIServer interface { GRPCChunkedReader SendAndClose(*Response) error Context() context.Context } type UploadStreamReq interface { GetStream() io.Reader SetStream(io.Reader) } // 封装了上传流API的客户端逻辑。记得将Req里的Stream字段设置为不需要序列化(json:"-") func UploadStreamClient[Resp any, Req UploadStreamReq, APIRet UploadStreamAPIClient](apiFn func(context.Context, ...grpc.CallOption) (APIRet, error), ctx context.Context, req Req) (Resp, *CodeError) { stream := req.GetStream() var ret Resp data, err := serder.ObjectToJSONEx(req) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } ctx2, cancelFn := context.WithCancel(ctx) defer cancelFn() cli, err := apiFn(ctx2) if err != nil { return ret, ExtractCodeError(err) } cw := NewChunkedWriter(cli) err = cw.WriteDataPart("", data) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } _, err = cw.WriteStreamPart("", stream) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } err = cw.Finish() if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } resp, err := cli.CloseAndRecv() if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } ret, err = serder.JSONToObjectEx[Resp](resp.Payload) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } return ret, nil } func UploadStreamServer[Resp any, Req UploadStreamReq, APIRet UploadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req APIRet) error { cr := NewChunkedReader(req) _, data, err := cr.NextDataPart() if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } _, pr, err := cr.NextPart() if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } rreq, err := serder.JSONToObjectEx[Req](data) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } rreq.SetStream(pr) resp, cerr := apiFn(req.Context(), rreq) if cerr != nil { return WrapCodeError(cerr) } respData, err := serder.ObjectToJSONEx(resp) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = req.SendAndClose(&Response{Payload: respData}) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } return nil } type DownloadStreamAPIClient interface { GRPCChunkedReader } type DownloadStreamAPIServer interface { GRPCChunkedWriter Context() context.Context } type DownloadStreamResp interface { GetStream() io.ReadCloser SetStream(io.ReadCloser) } // 封装了下载流API的客户端逻辑。记得将Resp里的Stream字段设置为不需要序列化(json:"-") func DownloadStreamClient[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIClient](apiFn func(context.Context, *Request, ...grpc.CallOption) (APIRet, error), ctx context.Context, req Req) (Resp, *CodeError) { var ret Resp data, err := serder.ObjectToJSONEx(req) if err != nil { return ret, Failed(errorcode.OperationFailed, err.Error()) } ctx2, cancelFn := context.WithCancel(ctx) cli, err := apiFn(ctx2, &Request{Payload: data}) if err != nil { cancelFn() return ret, ExtractCodeError(err) } cr := NewChunkedReader(cli) _, data, err = cr.NextDataPart() if err != nil { cancelFn() return ret, Failed(errorcode.OperationFailed, err.Error()) } resp, err := serder.JSONToObjectEx[Resp](data) if err != nil { cancelFn() return ret, Failed(errorcode.OperationFailed, err.Error()) } _, pr, err := cr.NextPart() if err != nil { cancelFn() return ret, Failed(errorcode.OperationFailed, err.Error()) } resp.SetStream(io2.DelegateReadCloser(pr, func() error { cancelFn() return nil })) return resp, nil } func DownloadStreamServer[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req *Request, ret APIRet) error { rreq, err := serder.JSONToObjectEx[Req](req.Payload) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } resp, cerr := apiFn(ret.Context(), rreq) if cerr != nil { return WrapCodeError(cerr) } defer resp.GetStream().Close() cw := NewChunkedWriter(ret) data, err := serder.ObjectToJSONEx(resp) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = cw.WriteDataPart("", data) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } _, err = cw.WriteStreamPart("", resp.GetStream()) if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = cw.Finish() if err != nil { return MakeCodeError(errorcode.OperationFailed, err.Error()) } return nil } type BidChannelAPIClient interface { Send(*Request) error Recv() (*Response, error) grpc.ClientStream } func BidChannelClient[Resp any, Req any, APIRet BidChannelAPIClient](apiFn func(context.Context, ...grpc.CallOption) (APIRet, error), ctx context.Context) BidChan[Resp, Req] { ctx, cancelFn := context.WithCancel(ctx) ret, err := apiFn(ctx) if err != nil { return NewFusedChan[Resp, Req](Failed(errorcode.OperationFailed, err.Error())) } return NewBidChanClient[Resp, Req](ret, cancelFn) } type BidChannelAPIServer interface { Send(*Response) error Recv() (*Request, error) grpc.ServerStream } func BidChannelServer[Resp any, Req any, APIArg BidChannelAPIServer](apiFn func(BidChan[Req, Resp]), arg APIArg) error { errCh := make(chan *CodeError, 1) ch := NewBidChanServer[Req, Resp](arg, errCh) go apiFn(ch) cerr := <-errCh if cerr != nil { return WrapCodeError(cerr) } return nil } func Failed(errCode ecode.ErrorCode, format string, args ...any) *CodeError { return &CodeError{ Code: string(errCode), Message: fmt.Sprintf(format, args...), } } // 定义一个额外的结构体,防止陷入 (*CodeError)(nil) != nil 的陷阱 type ErrorCodeError struct { CE *CodeError } func (c *ErrorCodeError) Error() string { return fmt.Sprintf("code: %s, message: %s", c.CE.Code, c.CE.Message) } func (c *CodeError) ToError() error { if c == nil { return nil } return &ErrorCodeError{CE: c} } func ExtractCodeError(err error) *CodeError { status, ok := status.FromError(err) if ok { dts := status.Details() if len(dts) > 0 { ce, ok := dts[0].(*CodeError) if ok { return ce } } } return Failed(errorcode.OperationFailed, err.Error()) } func MakeCodeError(code ecode.ErrorCode, msg string) error { ce, _ := status.New(codes.Unknown, "custom error").WithDetails(Failed(code, msg)) return ce.Err() } func WrapCodeError(ce *CodeError) error { e, _ := status.New(codes.Unknown, "custom error").WithDetails(ce) return e.Err() }