@@ -6,6 +6,10 @@ require ( | |||
github.com/BurntSushi/toml v1.1.0 // indirect | |||
github.com/apache/dubbo-getty v1.4.8 | |||
github.com/dubbogo/gost v1.11.23 | |||
github.com/dubbogo/tools v1.0.9 // indirect | |||
github.com/fagongzi/goetty v1.3.1 | |||
github.com/fagongzi/log v0.0.0-20170831135209-9a647df25e0e | |||
github.com/fagongzi/util v0.0.0-20181102105153-fd38e0f42a4f | |||
github.com/golang/snappy v0.0.4 // indirect | |||
github.com/natefinch/lumberjack v2.0.0+incompatible | |||
github.com/pkg/errors v0.9.1 | |||
@@ -1,7 +1,16 @@ | |||
package error | |||
import ( | |||
"github.com/pkg/errors" | |||
) | |||
type ErrorCode int32 | |||
const ( | |||
ErrorCode_IllegalState ErrorCode = 40001 | |||
) | |||
var ( | |||
Error_TooManySessions = errors.New("too many seeessions") | |||
Error_HeartBeatTimeOut = errors.New("heart beat time out") | |||
) |
@@ -34,6 +34,7 @@ func GetDefaultGettyConfig() GettyConfig { | |||
TCPReadTimeout: time.Second, | |||
TCPWriteTimeout: 5 * time.Second, | |||
WaitTimeout: time.Second, | |||
CronPeriod: time.Second, | |||
MaxMsgLen: 4096, | |||
SessionName: "rpc_client", | |||
}, | |||
@@ -46,6 +47,7 @@ type GettySessionParam struct { | |||
TCPNoDelay bool `default:"true" yaml:"tcp_no_delay" json:"tcp_no_delay,omitempty"` | |||
TCPKeepAlive bool `default:"true" yaml:"tcp_keep_alive" json:"tcp_keep_alive,omitempty"` | |||
KeepAlivePeriod time.Duration `default:"180s" yaml:"keep_alive_period" json:"keep_alive_period,omitempty"` | |||
CronPeriod time.Duration `default:"1s" yaml:"keep_alive_period" json:"keep_alive_period,omitempty"` | |||
TCPRBufSize int `default:"262144" yaml:"tcp_r_buf_size" json:"tcp_r_buf_size,omitempty"` | |||
TCPWBufSize int `default:"65536" yaml:"tcp_w_buf_size" json:"tcp_w_buf_size,omitempty"` | |||
TCPReadTimeout time.Duration `default:"1s" yaml:"tcp_read_timeout" json:"tcp_read_timeout,omitempty"` | |||
@@ -0,0 +1,68 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
model2 "github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &BranchRegisterRequestCodec{}) | |||
} | |||
type BranchRegisterRequestCodec struct { | |||
} | |||
func (g *BranchRegisterRequestCodec) Decode(in []byte) interface{} { | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.BranchRegisterRequest{} | |||
length := ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.Xid = string(Read(buf, bytes)) | |||
} | |||
msg.BranchType = model2.BranchType(ReadByte(buf)) | |||
length = ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.ResourceId = string(Read(buf, bytes)) | |||
} | |||
length32 := ReadUInt32(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length32) | |||
msg.LockKey = string(Read(buf, bytes)) | |||
} | |||
length32 = ReadUInt32(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length32) | |||
msg.ApplicationData = Read(buf, bytes) | |||
} | |||
return msg | |||
} | |||
func (c *BranchRegisterRequestCodec) Encode(in interface{}) []byte { | |||
buf := goetty.NewByteBuf(0) | |||
req, _ := in.(message.BranchRegisterRequest) | |||
Write16String(req.Xid, buf) | |||
buf.WriteByte(byte(req.BranchType)) | |||
Write16String(req.ResourceId, buf) | |||
Write32String(req.LockKey, buf) | |||
Write32String(string(req.ApplicationData), buf) | |||
return buf.RawBuf() | |||
} | |||
func (g *BranchRegisterRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_BranchRegister | |||
} |
@@ -0,0 +1,74 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &BranchRegisterResponseCodec{}) | |||
} | |||
type BranchRegisterResponseCodec struct { | |||
} | |||
func (g *BranchRegisterResponseCodec) Decode(in []byte) interface{} { | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.BranchRegisterResponse{} | |||
resultCode := ReadByte(buf) | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length := ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.Msg = string(Read(buf, bytes)) | |||
} | |||
} | |||
exceptionCode := ReadByte(buf) | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
msg.BranchId = int64(ReadUInt64(buf)) | |||
return msg | |||
} | |||
func (c *BranchRegisterResponseCodec) Encode(in interface{}) []byte { | |||
buf := goetty.NewByteBuf(0) | |||
resp, _ := in.(message.BranchRegisterResponse) | |||
resultCode := ReadByte(buf) | |||
if resultCode == byte(message.ResultCodeFailed) { | |||
var msg string | |||
if len(resp.Msg) > 128 { | |||
msg = resp.Msg[:128] | |||
} else { | |||
msg = resp.Msg | |||
} | |||
Write16String(msg, buf) | |||
} | |||
buf.WriteByte(byte(resp.TransactionExceptionCode)) | |||
branchID := uint64(resp.BranchId) | |||
branchIdBytes := []byte{ | |||
byte(branchID >> 56), | |||
byte(branchID >> 48), | |||
byte(branchID >> 40), | |||
byte(branchID >> 32), | |||
byte(branchID >> 24), | |||
byte(branchID >> 16), | |||
byte(branchID >> 8), | |||
byte(branchID), | |||
} | |||
buf.Write(branchIdBytes) | |||
return buf.RawBuf() | |||
} | |||
func (g *BranchRegisterResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_BranchRegisterResult | |||
} |
@@ -2,244 +2,100 @@ package codec | |||
import ( | |||
"bytes" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"sync" | |||
) | |||
import ( | |||
"vimagination.zapto.org/byteio" | |||
) | |||
type SerializerType byte | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
type CodecType byte | |||
// TODO 待重构 | |||
const ( | |||
SEATA = byte(0x1) | |||
PROTOBUF = byte(0x2) | |||
KRYO = byte(0x4) | |||
FST = byte(0x8) | |||
CodeTypeSeata = CodecType(0x1) | |||
CodeTypeProtobuf = CodecType(0x2) | |||
CodeTypeKRYO = CodecType(0x4) | |||
CodeTypeFST = CodecType(0x8) | |||
) | |||
type Encoder func(in interface{}) []byte | |||
type Codec interface { | |||
Encode(in interface{}) []byte | |||
Decode(in []byte) interface{} | |||
GetMessageType() message.MessageType | |||
} | |||
type Decoder func(in []byte) (interface{}, int) | |||
var ( | |||
codecManager *CodecManager | |||
onceCodecManager = &sync.Once{} | |||
) | |||
func MessageEncoder(codecType byte, in interface{}) []byte { | |||
switch codecType { | |||
case SEATA: | |||
return SeataEncoder(in) | |||
default: | |||
log.Errorf("not support codecType, %s", codecType) | |||
return nil | |||
func GetCodecManager() *CodecManager { | |||
if codecManager == nil { | |||
onceCodecManager.Do(func() { | |||
codecManager = &CodecManager{ | |||
codecMap: make(map[CodecType]map[message.MessageType]Codec, 0), | |||
} | |||
}) | |||
} | |||
return codecManager | |||
} | |||
func MessageDecoder(codecType byte, in []byte) (interface{}, int) { | |||
switch codecType { | |||
case SEATA: | |||
return SeataDecoder(in) | |||
default: | |||
log.Errorf("not support codecType, %s", codecType) | |||
return nil, 0 | |||
} | |||
type CodecManager struct { | |||
mutex sync.Mutex | |||
codecMap map[CodecType]map[message.MessageType]Codec | |||
} | |||
func SeataEncoder(in interface{}) []byte { | |||
var result = make([]byte, 0) | |||
msg := in.(message.MessageTypeAware) | |||
typeCode := msg.GetTypeCode() | |||
encoder := getMessageEncoder(typeCode) | |||
func (c *CodecManager) RegisterCodec(codecType CodecType, codec Codec) { | |||
c.mutex.Lock() | |||
defer c.mutex.Unlock() | |||
codecTypeMap := c.codecMap[codecType] | |||
if codecTypeMap == nil { | |||
codecTypeMap = make(map[message.MessageType]Codec, 0) | |||
c.codecMap[codecType] = codecTypeMap | |||
} | |||
codecTypeMap[codec.GetMessageType()] = codec | |||
} | |||
typeC := uint16(typeCode) | |||
if encoder != nil { | |||
body := encoder(in) | |||
result = append(result, []byte{byte(typeC >> 8), byte(typeC)}...) | |||
result = append(result, body...) | |||
func (c *CodecManager) GetCodec(codecType CodecType, msgType message.MessageType) Codec { | |||
if m := c.codecMap[codecType]; m != nil { | |||
return m[msgType] | |||
} | |||
return result | |||
return nil | |||
} | |||
func SeataDecoder(in []byte) (interface{}, int) { | |||
func (c *CodecManager) Decode(codecType CodecType, in []byte) interface{} { | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
typeCode, _, _ := r.ReadInt16() | |||
codec := c.GetCodec(codecType, message.MessageType(typeCode)) | |||
decoder := getMessageDecoder(message.MessageType(typeCode)) | |||
if decoder != nil { | |||
return decoder(in[2:]) | |||
} | |||
return nil, 0 | |||
} | |||
func getMessageEncoder(typeCode message.MessageType) Encoder { | |||
switch typeCode { | |||
case message.MessageType_SeataMerge: | |||
return MergedWarpMessageEncoder | |||
case message.MessageType_SeataMergeResult: | |||
return MergeResultMessageEncoder | |||
case message.MessageType_RegClt: | |||
return RegisterTMRequestEncoder | |||
case message.MessageType_RegCltResult: | |||
return RegisterTMResponseEncoder | |||
case message.MessageType_RegRm: | |||
return RegisterRMRequestEncoder | |||
case message.MessageType_RegRmResult: | |||
return RegisterRMResponseEncoder | |||
case message.MessageType_BranchCommit: | |||
return BranchCommitRequestEncoder | |||
case message.MessageType_BranchRollback: | |||
return BranchRollbackRequestEncoder | |||
case message.MessageType_GlobalReport: | |||
return GlobalReportRequestEncoder | |||
default: | |||
var encoder Encoder | |||
encoder = getMergeRequestMessageEncoder(typeCode) | |||
if encoder != nil { | |||
return encoder | |||
} | |||
encoder = getMergeResponseMessageEncoder(typeCode) | |||
if encoder != nil { | |||
return encoder | |||
} | |||
log.Errorf("not support typeCode, %d", typeCode) | |||
if codec == nil { | |||
log.Errorf("This message type [%v] has no codec to decode", typeCode) | |||
return nil | |||
} | |||
return codec.Decode(in[2:]) | |||
} | |||
func getMergeRequestMessageEncoder(typeCode message.MessageType) Encoder { | |||
switch typeCode { | |||
case message.MessageType_GlobalBegin: | |||
return GlobalBeginRequestEncoder | |||
case message.MessageType_GlobalCommit: | |||
return GlobalCommitRequestEncoder | |||
case message.MessageType_GlobalRollback: | |||
return GlobalRollbackRequestEncoder | |||
case message.MessageType_GlobalStatus: | |||
return GlobalStatusRequestEncoder | |||
case message.MessageType_GlobalLockQuery: | |||
return GlobalLockQueryRequestEncoder | |||
case message.MessageType_BranchRegister: | |||
return BranchRegisterRequestEncoder | |||
case message.MessageType_BranchStatusReport: | |||
return BranchReportRequestEncoder | |||
case message.MessageType_GlobalReport: | |||
return GlobalReportRequestEncoder | |||
default: | |||
break | |||
} | |||
return nil | |||
} | |||
func getMergeResponseMessageEncoder(typeCode message.MessageType) Encoder { | |||
switch typeCode { | |||
case message.MessageType_GlobalBeginResult: | |||
return GlobalBeginResponseEncoder | |||
case message.MessageType_GlobalCommitResult: | |||
return GlobalCommitResponseEncoder | |||
case message.MessageType_GlobalRollbackResult: | |||
return GlobalRollbackResponseEncoder | |||
case message.MessageType_GlobalStatusResult: | |||
return GlobalStatusResponseEncoder | |||
case message.MessageType_GlobalLockQueryResult: | |||
return GlobalLockQueryResponseEncoder | |||
case message.MessageType_BranchRegisterResult: | |||
return BranchRegisterResponseEncoder | |||
case message.MessageType_BranchStatusReportResult: | |||
return BranchReportResponseEncoder | |||
case message.MessageType_BranchCommitResult: | |||
return BranchCommitResponseEncoder | |||
case message.MessageType_BranchRollbackResult: | |||
return BranchRollbackResponseEncoder | |||
case message.MessageType_GlobalReportResult: | |||
return GlobalReportResponseEncoder | |||
default: | |||
break | |||
} | |||
return nil | |||
} | |||
func (c *CodecManager) Encode(codecType CodecType, in interface{}) []byte { | |||
var result = make([]byte, 0) | |||
msg := in.(message.MessageTypeAware) | |||
typeCode := msg.GetTypeCode() | |||
func getMessageDecoder(typeCode message.MessageType) Decoder { | |||
switch typeCode { | |||
case message.MessageType_SeataMerge: | |||
return MergedWarpMessageDecoder | |||
case message.MessageType_SeataMergeResult: | |||
return MergeResultMessageDecoder | |||
case message.MessageType_RegClt: | |||
return RegisterTMRequestDecoder | |||
case message.MessageType_RegCltResult: | |||
return RegisterTMResponseDecoder | |||
case message.MessageType_RegRm: | |||
return RegisterRMRequestDecoder | |||
case message.MessageType_RegRmResult: | |||
return RegisterRMResponseDecoder | |||
case message.MessageType_BranchCommit: | |||
return BranchCommitRequestDecoder | |||
case message.MessageType_BranchRollback: | |||
return BranchRollbackRequestDecoder | |||
case message.MessageType_GlobalReport: | |||
return GlobalReportRequestDecoder | |||
default: | |||
var Decoder Decoder | |||
Decoder = getMergeRequestMessageDecoder(typeCode) | |||
if Decoder != nil { | |||
return Decoder | |||
} | |||
Decoder = getMergeResponseMessageDecoder(typeCode) | |||
if Decoder != nil { | |||
return Decoder | |||
} | |||
log.Errorf("not support typeCode, %d", typeCode) | |||
codec := c.GetCodec(codecType, typeCode) | |||
if codec == nil { | |||
log.Errorf("This message type [%v] has no codec to encode", typeCode) | |||
return nil | |||
} | |||
} | |||
func getMergeRequestMessageDecoder(typeCode message.MessageType) Decoder { | |||
switch typeCode { | |||
case message.MessageType_GlobalBegin: | |||
return GlobalBeginRequestDecoder | |||
case message.MessageType_GlobalCommit: | |||
return GlobalCommitRequestDecoder | |||
case message.MessageType_GlobalRollback: | |||
return GlobalRollbackRequestDecoder | |||
case message.MessageType_GlobalStatus: | |||
return GlobalStatusRequestDecoder | |||
case message.MessageType_GlobalLockQuery: | |||
return GlobalLockQueryRequestDecoder | |||
case message.MessageType_BranchRegister: | |||
return BranchRegisterRequestDecoder | |||
case message.MessageType_BranchStatusReport: | |||
return BranchReportRequestDecoder | |||
case message.MessageType_GlobalReport: | |||
return GlobalReportRequestDecoder | |||
default: | |||
break | |||
} | |||
return nil | |||
} | |||
body := codec.Encode(in) | |||
typeC := uint16(typeCode) | |||
result = append(result, []byte{byte(typeC >> 8), byte(typeC)}...) | |||
result = append(result, body...) | |||
func getMergeResponseMessageDecoder(typeCode message.MessageType) Decoder { | |||
switch typeCode { | |||
case message.MessageType_GlobalBeginResult: | |||
return GlobalBeginResponseDecoder | |||
case message.MessageType_GlobalCommitResult: | |||
return GlobalCommitResponseDecoder | |||
case message.MessageType_GlobalRollbackResult: | |||
return GlobalRollbackResponseDecoder | |||
case message.MessageType_GlobalStatusResult: | |||
return GlobalStatusResponseDecoder | |||
case message.MessageType_GlobalLockQueryResult: | |||
return GlobalLockQueryResponseDecoder | |||
case message.MessageType_BranchRegisterResult: | |||
return BranchRegisterResponseDecoder | |||
case message.MessageType_BranchStatusReportResult: | |||
return BranchReportResponseDecoder | |||
case message.MessageType_BranchCommitResult: | |||
return BranchCommitResponseDecoder | |||
case message.MessageType_BranchRollbackResult: | |||
return BranchRollbackResponseDecoder | |||
case message.MessageType_GlobalReportResult: | |||
return GlobalReportResponseDecoder | |||
default: | |||
break | |||
} | |||
return nil | |||
return result | |||
} |
@@ -0,0 +1,189 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
"github.com/fagongzi/util/hack" | |||
) | |||
// Write16String write string value with 16 byte length | |||
func Write16String(value string, buf *goetty.ByteBuf) { | |||
if value != "" { | |||
buf.WriteUInt16(uint16(len(value))) | |||
buf.WriteString(value) | |||
} else { | |||
buf.WriteUInt16(uint16(0)) | |||
} | |||
} | |||
// Write16String write string value with 16 byte length | |||
func Write32String(value string, buf *goetty.ByteBuf) { | |||
if value != "" { | |||
buf.WriteUInt32(uint32(len(value))) | |||
buf.WriteString(value) | |||
} else { | |||
buf.WriteUInt32(uint32(0)) | |||
} | |||
} | |||
// Write8String write string value with 8 byte length | |||
func Write8String(value string, buf *goetty.ByteBuf) { | |||
if value != "" { | |||
buf.WriteByte(uint8(len(value))) | |||
buf.WriteString(value) | |||
} else { | |||
buf.WriteByte(uint8(0)) | |||
} | |||
} | |||
// ReadString read string value | |||
func ReadString(buf *goetty.ByteBuf) string { | |||
size := ReadUInt16(buf) | |||
if size == 0 { | |||
return "" | |||
} | |||
_, value, _ := buf.ReadBytes(int(size)) | |||
return hack.SliceToString(value) | |||
} | |||
// MaybeReadString maybe read string value | |||
func MaybeReadString(buf *goetty.ByteBuf) (string, bool) { | |||
if buf.Readable() < 2 { | |||
return "", false | |||
} | |||
size := ReadUInt16(buf) | |||
if size == 0 { | |||
return "", true | |||
} | |||
if buf.Readable() < int(size) { | |||
return "", false | |||
} | |||
_, value, _ := buf.ReadBytes(int(size)) | |||
return hack.SliceToString(value), true | |||
} | |||
// WriteBigString write big string | |||
func WriteBigString(value string, buf *goetty.ByteBuf) { | |||
if value != "" { | |||
buf.WriteInt(len(value)) | |||
buf.WriteString(value) | |||
} else { | |||
buf.WriteInt(0) | |||
} | |||
} | |||
// ReadBigString read big string | |||
func ReadBigString(buf *goetty.ByteBuf) string { | |||
size := ReadInt(buf) | |||
if size == 0 { | |||
return "" | |||
} | |||
_, value, _ := buf.ReadBytes(size) | |||
return hack.SliceToString(value) | |||
} | |||
// MaybeReadBigString maybe read string value | |||
func MaybeReadBigString(buf *goetty.ByteBuf) (string, bool) { | |||
if buf.Readable() < 4 { | |||
return "", false | |||
} | |||
size := ReadInt(buf) | |||
if size == 0 { | |||
return "", true | |||
} | |||
if buf.Readable() < size { | |||
return "", false | |||
} | |||
_, value, _ := buf.ReadBytes(int(size)) | |||
return hack.SliceToString(value), true | |||
} | |||
// ReadUInt64 read uint64 value | |||
func ReadUInt64(buf *goetty.ByteBuf) uint64 { | |||
value, _ := buf.ReadUInt64() | |||
return value | |||
} | |||
// ReadUInt16 read uint16 value | |||
func ReadUInt16(buf *goetty.ByteBuf) uint16 { | |||
value, _ := buf.ReadUInt16() | |||
return value | |||
} | |||
// ReadUInt32 read uint16 value | |||
func ReadUInt32(buf *goetty.ByteBuf) uint32 { | |||
value, _ := buf.ReadUInt32() | |||
return value | |||
} | |||
// ReadUInt32 read uint16 value | |||
func Read(buf *goetty.ByteBuf, p []byte) []byte { | |||
buf.Read(p) | |||
return p | |||
} | |||
// ReadInt read int value | |||
func ReadInt(buf *goetty.ByteBuf) int { | |||
value, _ := buf.ReadInt() | |||
return value | |||
} | |||
// ReadByte read byte value | |||
func ReadByte(buf *goetty.ByteBuf) byte { | |||
value, _ := buf.ReadByte() | |||
return value | |||
} | |||
// ReadBytes read bytes value | |||
func ReadBytes(n int, buf *goetty.ByteBuf) []byte { | |||
_, value, _ := buf.ReadBytes(n) | |||
return value | |||
} | |||
// WriteBool write bool value | |||
func WriteBool(value bool, out *goetty.ByteBuf) { | |||
out.WriteByte(boolToByte(value)) | |||
} | |||
// WriteSlice write slice value | |||
func WriteSlice(value []byte, buf *goetty.ByteBuf) { | |||
buf.WriteUInt16(uint16(len(value))) | |||
if len(value) > 0 { | |||
buf.Write(value) | |||
} | |||
} | |||
// ReadSlice read slice value | |||
func ReadSlice(buf *goetty.ByteBuf) []byte { | |||
l, _ := buf.ReadUInt16() | |||
if l == 0 { | |||
return nil | |||
} | |||
_, data, _ := buf.ReadBytes(int(l)) | |||
return data | |||
} | |||
func boolToByte(value bool) byte { | |||
if value { | |||
return 1 | |||
} | |||
return 0 | |||
} | |||
func byteToBool(value byte) bool { | |||
if value == 1 { | |||
return true | |||
} | |||
return false | |||
} |
@@ -0,0 +1,51 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
type CommonGlobalEndRequestCodec struct { | |||
} | |||
func (c *CommonGlobalEndRequestCodec) Encode(in interface{}) []byte { | |||
req, _ := in.(message.AbstractGlobalEndRequest) | |||
buf := goetty.NewByteBuf(0) | |||
Write16String(req.Xid, buf) | |||
Write16String(string(req.ExtraData), buf) | |||
return buf.RawBuf() | |||
} | |||
func (c *CommonGlobalEndRequestCodec) Decode(in []byte) interface{} { | |||
res := message.AbstractGlobalEndRequest{} | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
var xidLen int | |||
if buf.Readable() >= 2 { | |||
xidLen = int(ReadUInt16(buf)) | |||
} | |||
if buf.Readable() >= xidLen { | |||
xidBytes := make([]byte, xidLen) | |||
xidBytes = Read(buf, xidBytes) | |||
res.Xid = string(xidBytes) | |||
} | |||
var extraDataLen int | |||
if buf.Readable() >= 2 { | |||
extraDataLen = int(ReadUInt16(buf)) | |||
} | |||
if buf.Readable() >= extraDataLen { | |||
extraDataBytes := make([]byte, xidLen) | |||
extraDataBytes = Read(buf, extraDataBytes) | |||
res.ExtraData = extraDataBytes | |||
} | |||
return res | |||
} |
@@ -0,0 +1,58 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
) | |||
type CommonGlobalEndResponseCodec struct { | |||
} | |||
func (c *CommonGlobalEndResponseCodec) Encode(in interface{}) []byte { | |||
buf := goetty.NewByteBuf(0) | |||
resp := in.(message.AbstractGlobalEndResponse) | |||
buf.WriteByte(byte(resp.ResultCode)) | |||
if resp.ResultCode == message.ResultCodeFailed { | |||
var msg string | |||
if len(resp.Msg) > 128 { | |||
msg = resp.Msg[:128] | |||
} else { | |||
msg = resp.Msg | |||
} | |||
Write16String(msg, buf) | |||
} | |||
buf.WriteByte(byte(resp.TransactionExceptionCode)) | |||
buf.WriteByte(byte(resp.GlobalStatus)) | |||
return buf.RawBuf() | |||
} | |||
func (c *CommonGlobalEndResponseCodec) Decode(in []byte) interface{} { | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.AbstractGlobalEndResponse{} | |||
resultCode := ReadByte(buf) | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length := ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.Msg = string(Read(buf, bytes)) | |||
} | |||
} | |||
exceptionCode := ReadByte(buf) | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
globalStatus := ReadByte(buf) | |||
msg.GlobalStatus = transaction.GlobalStatus(globalStatus) | |||
return msg | |||
} |
@@ -0,0 +1,72 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
type AbstractIdentifyRequestCodec struct { | |||
} | |||
func (c *AbstractIdentifyRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.AbstractIdentifyRequest) | |||
buf := goetty.NewByteBuf(0) | |||
Write16String(req.Version, buf) | |||
Write16String(req.ApplicationId, buf) | |||
Write16String(req.TransactionServiceGroup, buf) | |||
Write16String(string(req.ExtraData), buf) | |||
return buf.RawBuf() | |||
} | |||
func (c *AbstractIdentifyRequestCodec) Decode(in []byte) interface{} { | |||
msg := message.AbstractIdentifyRequest{} | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
var len uint16 | |||
if buf.Readable() < 2 { | |||
return msg | |||
} | |||
len = ReadUInt16(buf) | |||
if uint16(buf.Readable()) < len { | |||
return msg | |||
} | |||
versionBytes := make([]byte, len) | |||
msg.Version = string(Read(buf, versionBytes)) | |||
if buf.Readable() < 2 { | |||
return msg | |||
} | |||
len = ReadUInt16(buf) | |||
if uint16(buf.Readable()) < len { | |||
return msg | |||
} | |||
applicationIdBytes := make([]byte, len) | |||
msg.ApplicationId = string(Read(buf, applicationIdBytes)) | |||
if buf.Readable() < 2 { | |||
return msg | |||
} | |||
len = ReadUInt16(buf) | |||
if uint16(buf.Readable()) < len { | |||
return msg | |||
} | |||
transactionServiceGroupBytes := make([]byte, len) | |||
msg.TransactionServiceGroup = string(Read(buf, transactionServiceGroupBytes)) | |||
if buf.Readable() < 2 { | |||
return msg | |||
} | |||
len = ReadUInt16(buf) | |||
if len > 0 && uint16(buf.Readable()) > len { | |||
extraDataBytes := make([]byte, len) | |||
msg.ExtraData = Read(buf, extraDataBytes) | |||
} | |||
return msg | |||
} |
@@ -0,0 +1,47 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
type AbstractIdentifyResponseCodec struct { | |||
} | |||
func (c *AbstractIdentifyResponseCodec) Encode(in interface{}) []byte { | |||
buf := goetty.NewByteBuf(0) | |||
resp := in.(message.AbstractIdentifyResponse) | |||
if resp.Identified { | |||
buf.WriteByte(byte(1)) | |||
} else { | |||
buf.WriteByte(byte(0)) | |||
} | |||
Write16String(resp.Version, buf) | |||
return buf.RawBuf() | |||
} | |||
func (c *AbstractIdentifyResponseCodec) Decode(in []byte) interface{} { | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.AbstractIdentifyResponse{} | |||
identified, _ := buf.ReadByte() | |||
if identified == byte(1) { | |||
msg.Identified = true | |||
} else if identified == byte(0) { | |||
msg.Identified = false | |||
} | |||
length := ReadUInt16(buf) | |||
if length > 0 { | |||
versionBytes := make([]byte, length) | |||
msg.Version = string(Read(buf, versionBytes)) | |||
} | |||
return msg | |||
} |
@@ -0,0 +1,45 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalBeginRequestCodec{}) | |||
} | |||
type GlobalBeginRequestCodec struct { | |||
} | |||
func (c *GlobalBeginRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.GlobalBeginRequest) | |||
buf := goetty.NewByteBuf(0) | |||
buf.WriteUInt32(uint32(req.Timeout)) | |||
Write16String(req.TransactionName, buf) | |||
return buf.RawBuf() | |||
} | |||
func (g *GlobalBeginRequestCodec) Decode(in []byte) interface{} { | |||
msg := message.GlobalBeginRequest{} | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg.Timeout = int32(ReadUInt32(buf)) | |||
len := ReadUInt16(buf) | |||
if len > 0 { | |||
transactionName := make([]byte, len) | |||
msg.TransactionName = string(Read(buf, transactionName)) | |||
} | |||
return msg | |||
} | |||
func (g *GlobalBeginRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalBegin | |||
} |
@@ -0,0 +1,76 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalBeginResponseCodec{}) | |||
} | |||
type GlobalBeginResponseCodec struct { | |||
} | |||
func (c *GlobalBeginResponseCodec) Encode(in interface{}) []byte { | |||
buf := goetty.NewByteBuf(0) | |||
resp := in.(message.GlobalBeginResponse) | |||
buf.WriteByte(byte(resp.ResultCode)) | |||
if resp.ResultCode == message.ResultCodeFailed { | |||
var msg string | |||
if len(resp.Msg) > 128 { | |||
msg = resp.Msg[:128] | |||
} else { | |||
msg = resp.Msg | |||
} | |||
Write16String(msg, buf) | |||
} | |||
buf.WriteByte(byte(resp.TransactionExceptionCode)) | |||
Write16String(resp.Xid, buf) | |||
Write16String(string(resp.ExtraData), buf) | |||
return buf.RawBuf() | |||
} | |||
func (g *GlobalBeginResponseCodec) Decode(in []byte) interface{} { | |||
var lenth uint16 | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.GlobalBeginResponse{} | |||
resultCode := ReadByte(buf) | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
lenth = ReadUInt16(buf) | |||
if lenth > 0 { | |||
bytes := make([]byte, lenth) | |||
msg.Msg = string(Read(buf, bytes)) | |||
} | |||
} | |||
exceptionCode := ReadByte(buf) | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
lenth = ReadUInt16(buf) | |||
if lenth > 0 { | |||
bytes := make([]byte, lenth) | |||
msg.Xid = string(Read(buf, bytes)) | |||
} | |||
lenth = ReadUInt16(buf) | |||
if lenth > 0 { | |||
bytes := make([]byte, lenth) | |||
msg.ExtraData = Read(buf, bytes) | |||
} | |||
return msg | |||
} | |||
func (g *GlobalBeginResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalBeginResult | |||
} |
@@ -0,0 +1,30 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalCommitRequestCodec{}) | |||
} | |||
type GlobalCommitRequestCodec struct { | |||
CommonGlobalEndRequestCodec | |||
} | |||
func (g *GlobalCommitRequestCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndRequestCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
return message.GlobalCommitRequest{ | |||
AbstractGlobalEndRequest: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalCommitRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.GlobalCommitRequest) | |||
return g.CommonGlobalEndRequestCodec.Encode(req.AbstractGlobalEndRequest) | |||
} | |||
func (g *GlobalCommitRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalCommit | |||
} |
@@ -0,0 +1,25 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalCommitResponseCodec{}) | |||
} | |||
type GlobalCommitResponseCodec struct { | |||
CommonGlobalEndResponseCodec | |||
} | |||
func (g *GlobalCommitResponseCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndResponseCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndResponse) | |||
return message.GlobalCommitResponse{ | |||
AbstractGlobalEndResponse: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalCommitResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalCommitResult | |||
} |
@@ -0,0 +1,21 @@ | |||
package codec | |||
//func init() { | |||
// GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalReportRequestCodec{}) | |||
//} | |||
// | |||
//type GlobalReportRequestCodec struct { | |||
// CommonGlobalEndRequestCodec | |||
//} | |||
// | |||
//func (g *GlobalReportRequestCodec) Decode(in []byte) interface{} { | |||
// req := g.CommonGlobalEndRequestCodec.Decode(in) | |||
// abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
// return message.GlobalCommitRequest{ | |||
// AbstractGlobalEndRequest: abstractGlobalEndRequest, | |||
// } | |||
//} | |||
// | |||
//func (g *GlobalReportRequestCodec) GetMessageType() message.MessageType { | |||
// return message.MessageType_GlobalCommit | |||
//} |
@@ -0,0 +1,25 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalReportResponseCodec{}) | |||
} | |||
type GlobalReportResponseCodec struct { | |||
CommonGlobalEndResponseCodec | |||
} | |||
func (g *GlobalReportResponseCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndResponseCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndResponse) | |||
return message.GlobalReportResponse{ | |||
AbstractGlobalEndResponse: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalReportResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalReportResult | |||
} |
@@ -0,0 +1,30 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalRollbackRequestCodec{}) | |||
} | |||
type GlobalRollbackRequestCodec struct { | |||
CommonGlobalEndRequestCodec | |||
} | |||
func (g *GlobalRollbackRequestCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndRequestCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
return message.GlobalCommitRequest{ | |||
AbstractGlobalEndRequest: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalRollbackRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.GlobalRollbackRequest) | |||
return g.CommonGlobalEndRequestCodec.Encode(req.AbstractGlobalEndRequest) | |||
} | |||
func (g *GlobalRollbackRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalRollback | |||
} |
@@ -0,0 +1,25 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalRollbackResponseCodec{}) | |||
} | |||
type GlobalRollbackResponseCodec struct { | |||
CommonGlobalEndResponseCodec | |||
} | |||
func (g *GlobalRollbackResponseCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndResponseCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndResponse) | |||
return message.GlobalRollbackResponse{ | |||
AbstractGlobalEndResponse: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalRollbackResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalRollbackResult | |||
} |
@@ -0,0 +1,30 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalStatusRequestCodec{}) | |||
} | |||
type GlobalStatusRequestCodec struct { | |||
CommonGlobalEndRequestCodec | |||
} | |||
func (g *GlobalStatusRequestCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndRequestCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
return message.GlobalStatusRequest{ | |||
AbstractGlobalEndRequest: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalStatusRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.GlobalStatusRequest) | |||
return g.CommonGlobalEndRequestCodec.Encode(req.AbstractGlobalEndRequest) | |||
} | |||
func (g *GlobalStatusRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalStatus | |||
} |
@@ -0,0 +1,25 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &GlobalStatusResponseCodec{}) | |||
} | |||
type GlobalStatusResponseCodec struct { | |||
CommonGlobalEndResponseCodec | |||
} | |||
func (g *GlobalStatusResponseCodec) Decode(in []byte) interface{} { | |||
req := g.CommonGlobalEndResponseCodec.Decode(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndResponse) | |||
return message.GlobalStatusResponse{ | |||
AbstractGlobalEndResponse: abstractGlobalEndRequest, | |||
} | |||
} | |||
func (g *GlobalStatusResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_GlobalStatusResult | |||
} |
@@ -0,0 +1,71 @@ | |||
package codec | |||
import ( | |||
"github.com/fagongzi/goetty" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &RegisterRMRequestCodec{}) | |||
} | |||
type RegisterRMRequestCodec struct { | |||
} | |||
func (g *RegisterRMRequestCodec) Decode(in []byte) interface{} { | |||
buf := goetty.NewByteBuf(len(in)) | |||
buf.Write(in) | |||
msg := message.RegisterRMRequest{} | |||
length := ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.Version = string(Read(buf, bytes)) | |||
} | |||
length = ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.ApplicationId = string(Read(buf, bytes)) | |||
} | |||
length = ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.TransactionServiceGroup = string(Read(buf, bytes)) | |||
} | |||
length = ReadUInt16(buf) | |||
if length > 0 { | |||
bytes := make([]byte, length) | |||
msg.ExtraData = Read(buf, bytes) | |||
} | |||
length32 := ReadUInt32(buf) | |||
if length32 > 0 { | |||
bytes := make([]byte, length32) | |||
msg.ResourceIds = string(Read(buf, bytes)) | |||
} | |||
return msg | |||
} | |||
func (c *RegisterRMRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.RegisterRMRequest) | |||
buf := goetty.NewByteBuf(0) | |||
Write16String(req.Version, buf) | |||
Write16String(req.ApplicationId, buf) | |||
Write16String(req.TransactionServiceGroup, buf) | |||
Write16String(string(req.ExtraData), buf) | |||
Write16String(req.ResourceIds, buf) | |||
return buf.RawBuf() | |||
} | |||
func (g *RegisterRMRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_RegRm | |||
} |
@@ -0,0 +1,25 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &RegisterRMResponseCodec{}) | |||
} | |||
type RegisterRMResponseCodec struct { | |||
AbstractIdentifyResponseCodec | |||
} | |||
func (g *RegisterRMResponseCodec) Decode(in []byte) interface{} { | |||
req := g.AbstractIdentifyResponseCodec.Decode(in) | |||
abstractIdentifyResponse := req.(message.AbstractIdentifyResponse) | |||
return message.RegisterRMResponse{ | |||
AbstractIdentifyResponse: abstractIdentifyResponse, | |||
} | |||
} | |||
func (g *RegisterRMResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_RegRmResult | |||
} |
@@ -0,0 +1,30 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &RegisterTMRequestCodec{}) | |||
} | |||
type RegisterTMRequestCodec struct { | |||
AbstractIdentifyRequestCodec | |||
} | |||
func (g *RegisterTMRequestCodec) Decode(in []byte) interface{} { | |||
req := g.AbstractIdentifyRequestCodec.Decode(in) | |||
abstractIdentifyRequest := req.(message.AbstractIdentifyRequest) | |||
return message.RegisterTMRequest{ | |||
AbstractIdentifyRequest: abstractIdentifyRequest, | |||
} | |||
} | |||
func (c *RegisterTMRequestCodec) Encode(in interface{}) []byte { | |||
req := in.(message.RegisterTMRequest) | |||
return c.AbstractIdentifyRequestCodec.Encode(req.AbstractIdentifyRequest) | |||
} | |||
func (g *RegisterTMRequestCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_RegClt | |||
} |
@@ -0,0 +1,30 @@ | |||
package codec | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
func init() { | |||
GetCodecManager().RegisterCodec(CodeTypeSeata, &RegisterTMResponseCodec{}) | |||
} | |||
type RegisterTMResponseCodec struct { | |||
AbstractIdentifyResponseCodec | |||
} | |||
func (g *RegisterTMResponseCodec) Decode(in []byte) interface{} { | |||
req := g.AbstractIdentifyResponseCodec.Decode(in) | |||
abstractIdentifyResponse := req.(message.AbstractIdentifyResponse) | |||
return message.RegisterTMResponse{ | |||
AbstractIdentifyResponse: abstractIdentifyResponse, | |||
} | |||
} | |||
func (c *RegisterTMResponseCodec) Encode(in interface{}) []byte { | |||
resp := in.(message.RegisterTMResponse) | |||
return c.AbstractIdentifyResponseCodec.Encode(resp.AbstractIdentifyResponse) | |||
} | |||
func (g *RegisterTMResponseCodec) GetMessageType() message.MessageType { | |||
return message.MessageType_RegCltResult | |||
} |
@@ -1,779 +0,0 @@ | |||
package codec | |||
import ( | |||
"bytes" | |||
model2 "github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
) | |||
import ( | |||
"vimagination.zapto.org/byteio" | |||
) | |||
// TODO 待重构 | |||
func AbstractResultMessageDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractResultMessage{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
totalReadN += 1 | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
return msg, totalReadN | |||
} | |||
func MergedWarpMessageDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
size16 int16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
result := message.MergedWarpMessage{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
r.ReadInt32() | |||
totalReadN += 4 | |||
size16, readN, _ = r.ReadInt16() | |||
totalReadN += readN | |||
result.Msgs = make([]message.MessageTypeAware, 0) | |||
for index := 0; index < int(size16); index++ { | |||
typeCode, _, _ := r.ReadInt16() | |||
totalReadN += 2 | |||
decoder := getMessageDecoder(message.MessageType(typeCode)) | |||
if decoder != nil { | |||
msg, readN := decoder(in[totalReadN:]) | |||
totalReadN += readN | |||
result.Msgs = append(result.Msgs, msg.(message.MessageTypeAware)) | |||
} | |||
} | |||
return result, totalReadN | |||
} | |||
func MergeResultMessageDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
size16 int16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
result := message.MergeResultMessage{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
r.ReadInt32() | |||
totalReadN += 4 | |||
size16, readN, _ = r.ReadInt16() | |||
totalReadN += readN | |||
result.Msgs = make([]message.MessageTypeAware, 0) | |||
for index := 0; index < int(size16); index++ { | |||
typeCode, _, _ := r.ReadInt16() | |||
totalReadN += 2 | |||
decoder := getMessageDecoder(message.MessageType(typeCode)) | |||
if decoder != nil { | |||
msg, readN := decoder(in[totalReadN:]) | |||
totalReadN += readN | |||
result.Msgs = append(result.Msgs, msg.(message.MessageTypeAware)) | |||
} | |||
} | |||
return result, totalReadN | |||
} | |||
func AbstractIdentifyRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractIdentifyRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Version, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ApplicationId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.TransactionServiceGroup, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ExtraData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ExtraData) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func AbstractIdentifyResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractIdentifyResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
identified, _ := r.ReadByte() | |||
totalReadN += 1 | |||
if identified == byte(1) { | |||
msg.Identified = true | |||
} else if identified == byte(0) { | |||
msg.Identified = false | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Version, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func RegisterRMRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length32 uint32 = 0 | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.RegisterRMRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Version, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ApplicationId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.TransactionServiceGroup, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ExtraData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ExtraData) | |||
totalReadN += readN | |||
} | |||
length32, readN, _ = r.ReadUint32() | |||
totalReadN += readN | |||
if length32 > 0 { | |||
msg.ResourceIds, readN, _ = r.ReadString(int(length32)) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func RegisterRMResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractIdentifyResponseDecoder(in) | |||
abstractIdentifyResponse := resp.(message.AbstractIdentifyResponse) | |||
msg := message.RegisterRMResponse{AbstractIdentifyResponse: abstractIdentifyResponse} | |||
return msg, totalReadN | |||
} | |||
func RegisterTMRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractIdentifyRequestDecoder(in) | |||
abstractIdentifyRequest := req.(message.AbstractIdentifyRequest) | |||
msg := message.RegisterTMRequest{AbstractIdentifyRequest: abstractIdentifyRequest} | |||
return msg, totalReadN | |||
} | |||
func RegisterTMResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractIdentifyResponseDecoder(in) | |||
abstractIdentifyResponse := resp.(message.AbstractIdentifyResponse) | |||
msg := message.RegisterRMResponse{AbstractIdentifyResponse: abstractIdentifyResponse} | |||
return msg, totalReadN | |||
} | |||
func AbstractTransactionResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractTransactionResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
return msg, totalReadN | |||
} | |||
func AbstractBranchEndRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractBranchEndRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
msg.BranchId, _, _ = r.ReadInt64() | |||
totalReadN += 8 | |||
branchType, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.BranchType = model2.BranchType(branchType) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ResourceId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ApplicationData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ApplicationData) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func AbstractBranchEndResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractBranchEndResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
msg.BranchId, _, _ = r.ReadInt64() | |||
totalReadN += 8 | |||
branchStatus, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.BranchStatus = model2.BranchStatus(branchStatus) | |||
return msg, totalReadN | |||
} | |||
func AbstractGlobalEndRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractGlobalEndRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ExtraData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ExtraData) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func AbstractGlobalEndResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.AbstractGlobalEndResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
globalStatus, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.GlobalStatus = transaction.GlobalStatus(globalStatus) | |||
return msg, totalReadN | |||
} | |||
func BranchCommitRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractBranchEndRequestDecoder(in) | |||
abstractBranchEndRequest := req.(message.AbstractBranchEndRequest) | |||
msg := message.BranchCommitRequest{AbstractBranchEndRequest: abstractBranchEndRequest} | |||
return msg, totalReadN | |||
} | |||
func BranchCommitResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractBranchEndResponseDecoder(in) | |||
abstractBranchEndResponse := resp.(message.AbstractBranchEndResponse) | |||
msg := message.BranchCommitResponse{AbstractBranchEndResponse: abstractBranchEndResponse} | |||
return msg, totalReadN | |||
} | |||
func BranchRegisterRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length32 uint32 = 0 | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.BranchRegisterRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
branchType, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.BranchType = model2.BranchType(branchType) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ResourceId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length32, readN, _ = r.ReadUint32() | |||
totalReadN += readN | |||
if length32 > 0 { | |||
msg.LockKey, readN, _ = r.ReadString(int(length32)) | |||
totalReadN += readN | |||
} | |||
length32, readN, _ = r.ReadUint32() | |||
totalReadN += readN | |||
if length32 > 0 { | |||
msg.ApplicationData = make([]byte, int(length32)) | |||
readN, _ := r.Read(msg.ApplicationData) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func BranchRegisterResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.BranchRegisterResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
msg.BranchId, readN, _ = r.ReadInt64() | |||
totalReadN += readN | |||
return msg, totalReadN | |||
} | |||
func BranchReportRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.BranchReportRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
msg.BranchId, _, _ = r.ReadInt64() | |||
branchStatus, _ := r.ReadByte() | |||
msg.Status = model2.BranchStatus(branchStatus) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ResourceId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ApplicationData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ApplicationData) | |||
totalReadN += readN | |||
} | |||
branchType, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.BranchType = model2.BranchType(branchType) | |||
return msg, totalReadN | |||
} | |||
func BranchReportResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractTransactionResponseDecoder(in) | |||
abstractTransactionResponse := resp.(message.AbstractTransactionResponse) | |||
msg := message.BranchReportResponse{AbstractTransactionResponse: abstractTransactionResponse} | |||
return msg, totalReadN | |||
} | |||
func BranchRollbackRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractBranchEndRequestDecoder(in) | |||
abstractBranchEndRequest := req.(message.AbstractBranchEndRequest) | |||
msg := message.BranchRollbackRequest{AbstractBranchEndRequest: abstractBranchEndRequest} | |||
return msg, totalReadN | |||
} | |||
func BranchRollbackResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractBranchEndResponseDecoder(in) | |||
abstractBranchEndResponse := resp.(message.AbstractBranchEndResponse) | |||
msg := message.BranchRollbackResponse{AbstractBranchEndResponse: abstractBranchEndResponse} | |||
return msg, totalReadN | |||
} | |||
func GlobalBeginRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.GlobalBeginRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
timeout, readN, _ := r.ReadInt32() | |||
totalReadN += readN | |||
msg.Timeout = timeout | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.TransactionName, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func GlobalBeginResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.GlobalBeginResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ExtraData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ExtraData) | |||
totalReadN += readN | |||
} | |||
return msg, totalReadN | |||
} | |||
func GlobalCommitRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractGlobalEndRequestDecoder(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
msg := message.GlobalCommitRequest{AbstractGlobalEndRequest: abstractGlobalEndRequest} | |||
return msg, totalReadN | |||
} | |||
func GlobalCommitResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractGlobalEndResponseDecoder(in) | |||
abstractGlobalEndResponse := resp.(message.AbstractGlobalEndResponse) | |||
msg := message.GlobalCommitResponse{AbstractGlobalEndResponse: abstractGlobalEndResponse} | |||
return msg, totalReadN | |||
} | |||
func GlobalLockQueryRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := BranchRegisterRequestDecoder(in) | |||
branchRegisterRequest := req.(message.BranchRegisterRequest) | |||
msg := message.GlobalLockQueryRequest{BranchRegisterRequest: branchRegisterRequest} | |||
return msg, totalReadN | |||
} | |||
func GlobalLockQueryResponseDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.GlobalLockQueryResponse{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
resultCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.ResultCode = message.ResultCode(resultCode) | |||
if msg.ResultCode == message.ResultCodeFailed { | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Msg, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
} | |||
exceptionCode, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.TransactionExceptionCode = transaction.TransactionExceptionCode(exceptionCode) | |||
lockable, readN, _ := r.ReadUint16() | |||
totalReadN += readN | |||
if lockable == uint16(1) { | |||
msg.Lockable = true | |||
} else if lockable == uint16(0) { | |||
msg.Lockable = false | |||
} | |||
return msg, totalReadN | |||
} | |||
func GlobalReportRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.GlobalReportRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.Xid, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ExtraData = make([]byte, int(length16)) | |||
readN, _ := r.Read(msg.ExtraData) | |||
totalReadN += readN | |||
} | |||
globalStatus, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.GlobalStatus = transaction.GlobalStatus(globalStatus) | |||
return msg, totalReadN | |||
} | |||
func GlobalReportResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractGlobalEndResponseDecoder(in) | |||
abstractGlobalEndResponse := resp.(message.AbstractGlobalEndResponse) | |||
msg := message.GlobalReportResponse{AbstractGlobalEndResponse: abstractGlobalEndResponse} | |||
return msg, totalReadN | |||
} | |||
func GlobalRollbackRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractGlobalEndRequestDecoder(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
msg := message.GlobalRollbackRequest{AbstractGlobalEndRequest: abstractGlobalEndRequest} | |||
return msg, totalReadN | |||
} | |||
func GlobalRollbackResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractGlobalEndResponseDecoder(in) | |||
abstractGlobalEndResponse := resp.(message.AbstractGlobalEndResponse) | |||
msg := message.GlobalRollbackResponse{AbstractGlobalEndResponse: abstractGlobalEndResponse} | |||
return msg, totalReadN | |||
} | |||
func GlobalStatusRequestDecoder(in []byte) (interface{}, int) { | |||
req, totalReadN := AbstractGlobalEndRequestDecoder(in) | |||
abstractGlobalEndRequest := req.(message.AbstractGlobalEndRequest) | |||
msg := message.GlobalStatusRequest{AbstractGlobalEndRequest: abstractGlobalEndRequest} | |||
return msg, totalReadN | |||
} | |||
func GlobalStatusResponseDecoder(in []byte) (interface{}, int) { | |||
resp, totalReadN := AbstractGlobalEndResponseDecoder(in) | |||
abstractGlobalEndResponse := resp.(message.AbstractGlobalEndResponse) | |||
msg := message.GlobalStatusResponse{AbstractGlobalEndResponse: abstractGlobalEndResponse} | |||
return msg, totalReadN | |||
} | |||
func UndoLogDeleteRequestDecoder(in []byte) (interface{}, int) { | |||
var ( | |||
length16 uint16 = 0 | |||
readN = 0 | |||
totalReadN = 0 | |||
) | |||
msg := message.UndoLogDeleteRequest{} | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(in)} | |||
branchType, _ := r.ReadByte() | |||
totalReadN += 1 | |||
msg.BranchType = model2.BranchType(branchType) | |||
length16, readN, _ = r.ReadUint16() | |||
totalReadN += readN | |||
if length16 > 0 { | |||
msg.ResourceId, readN, _ = r.ReadString(int(length16)) | |||
totalReadN += readN | |||
} | |||
day, readN, _ := r.ReadInt16() | |||
msg.SaveDays = message.MessageType(day) | |||
totalReadN += readN | |||
return msg, totalReadN | |||
} |
@@ -1,563 +0,0 @@ | |||
package codec | |||
import ( | |||
"bytes" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
import ( | |||
"vimagination.zapto.org/byteio" | |||
) | |||
// TODO 待重构 | |||
func AbstractResultMessageEncoder(in interface{}) []byte { | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
msgs := in.(message.AbstractResultMessage) | |||
w.WriteByte(byte(msgs.ResultCode)) | |||
if msgs.ResultCode == message.ResultCodeFailed { | |||
var msg string | |||
if msgs.Msg != "" { | |||
if len(msgs.Msg) > 128 { | |||
msg = msgs.Msg[:128] | |||
} else { | |||
msg = msgs.Msg | |||
} | |||
// 暂时不考虑 msg.Msg 包含中文的情况,这样字符串的长度就是 byte 数组的长度 | |||
w.WriteInt16(int16(len(msg))) | |||
w.WriteString(msg) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
} | |||
return b.Bytes() | |||
} | |||
func MergedWarpMessageEncoder(in interface{}) []byte { | |||
var ( | |||
b bytes.Buffer | |||
result = make([]byte, 0) | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.MergedWarpMessage) | |||
w.WriteInt16(int16(len(req.Msgs))) | |||
for _, msg := range req.Msgs { | |||
encoder := getMessageEncoder(msg.GetTypeCode()) | |||
if encoder != nil { | |||
data := encoder(msg) | |||
w.WriteInt16(int16(msg.GetTypeCode())) | |||
w.Write(data) | |||
} | |||
} | |||
size := uint32(b.Len()) | |||
result = append(result, []byte{byte(size >> 24), byte(size >> 16), byte(size >> 8), byte(size)}...) | |||
result = append(result, b.Bytes()...) | |||
if len(req.Msgs) > 20 { | |||
log.Debugf("msg in one packet: %s ,buffer size: %s", len(req.Msgs), size) | |||
} | |||
return result | |||
} | |||
func MergeResultMessageEncoder(in interface{}) []byte { | |||
var ( | |||
b bytes.Buffer | |||
result = make([]byte, 0) | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.MergeResultMessage) | |||
w.WriteInt16(int16(len(req.Msgs))) | |||
for _, msg := range req.Msgs { | |||
encoder := getMessageEncoder(msg.GetTypeCode()) | |||
if encoder != nil { | |||
data := encoder(msg) | |||
w.WriteInt16(int16(msg.GetTypeCode())) | |||
w.Write(data) | |||
} | |||
} | |||
size := uint32(b.Len()) | |||
result = append(result, []byte{byte(size >> 24), byte(size >> 16), byte(size >> 8), byte(size)}...) | |||
result = append(result, b.Bytes()...) | |||
if len(req.Msgs) > 20 { | |||
log.Debugf("msg in one packet: %s ,buffer size: %s", len(req.Msgs), size) | |||
} | |||
return result | |||
} | |||
func AbstractIdentifyRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req := in.(message.AbstractIdentifyRequest) | |||
if req.Version != "" { | |||
w.WriteInt16(int16(len(req.Version))) | |||
w.WriteString(req.Version) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.ApplicationId != "" { | |||
w.WriteInt16(int16(len(req.ApplicationId))) | |||
w.WriteString(req.ApplicationId) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.TransactionServiceGroup != "" { | |||
w.WriteInt16(int16(len(req.TransactionServiceGroup))) | |||
w.WriteString(req.TransactionServiceGroup) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.ExtraData != nil { | |||
w.WriteUint16(uint16(len(req.ExtraData))) | |||
w.Write(req.ExtraData) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
return b.Bytes() | |||
} | |||
func AbstractIdentifyResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.AbstractIdentifyResponse) | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
if resp.Identified { | |||
w.WriteByte(byte(1)) | |||
} else { | |||
w.WriteByte(byte(0)) | |||
} | |||
if resp.Version != "" { | |||
w.WriteInt16(int16(len(resp.Version))) | |||
w.WriteString(resp.Version) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
return b.Bytes() | |||
} | |||
func RegisterRMRequestEncoder(in interface{}) []byte { | |||
req := in.(message.RegisterRMRequest) | |||
data := AbstractIdentifyRequestEncoder(req.AbstractIdentifyRequest) | |||
var ( | |||
zero32 int32 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
if req.ResourceIds != "" { | |||
w.WriteInt32(int32(len(req.ResourceIds))) | |||
w.WriteString(req.ResourceIds) | |||
} else { | |||
w.WriteInt32(zero32) | |||
} | |||
result := append(data, b.Bytes()...) | |||
return result | |||
} | |||
func RegisterRMResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.RegisterRMResponse) | |||
return AbstractIdentifyResponseEncoder(resp.AbstractIdentifyResponse) | |||
} | |||
func RegisterTMRequestEncoder(in interface{}) []byte { | |||
req := in.(message.RegisterTMRequest) | |||
return AbstractIdentifyRequestEncoder(req.AbstractIdentifyRequest) | |||
} | |||
func RegisterTMResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.RegisterTMResponse) | |||
return AbstractIdentifyResponseEncoder(resp.AbstractIdentifyResponse) | |||
} | |||
func AbstractTransactionResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.AbstractTransactionResponse) | |||
data := AbstractResultMessageEncoder(resp.AbstractResultMessage) | |||
result := append(data, byte(resp.TransactionExceptionCode)) | |||
return result | |||
} | |||
func AbstractBranchEndRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero32 int32 = 0 | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.AbstractBranchEndRequest) | |||
if req.Xid != "" { | |||
w.WriteInt16(int16(len(req.Xid))) | |||
w.WriteString(req.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
w.WriteInt64(req.BranchId) | |||
w.WriteByte(byte(req.BranchType)) | |||
if req.ResourceId != "" { | |||
w.WriteInt16(int16(len(req.ResourceId))) | |||
w.WriteString(req.ResourceId) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.ApplicationData != nil { | |||
w.WriteUint32(uint32(len(req.ApplicationData))) | |||
w.Write(req.ApplicationData) | |||
} else { | |||
w.WriteInt32(zero32) | |||
} | |||
return b.Bytes() | |||
} | |||
func AbstractBranchEndResponseEncoder(in interface{}) []byte { | |||
resp, _ := in.(message.AbstractBranchEndResponse) | |||
data := AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
if resp.Xid != "" { | |||
w.WriteInt16(int16(len(resp.Xid))) | |||
w.WriteString(resp.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
w.WriteInt64(resp.BranchId) | |||
w.WriteByte(byte(resp.BranchStatus)) | |||
result := append(data, b.Bytes()...) | |||
return result | |||
} | |||
func AbstractGlobalEndRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.AbstractGlobalEndRequest) | |||
if req.Xid != "" { | |||
w.WriteInt16(int16(len(req.Xid))) | |||
w.WriteString(req.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.ExtraData != nil { | |||
w.WriteUint16(uint16(len(req.ExtraData))) | |||
w.Write(req.ExtraData) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
return b.Bytes() | |||
} | |||
func AbstractGlobalEndResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.AbstractGlobalEndResponse) | |||
data := AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
result := append(data, byte(resp.GlobalStatus)) | |||
return result | |||
} | |||
func BranchCommitRequestEncoder(in interface{}) []byte { | |||
req := in.(message.BranchCommitRequest) | |||
return AbstractBranchEndRequestEncoder(req.AbstractBranchEndRequest) | |||
} | |||
func BranchCommitResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.BranchCommitResponse) | |||
return AbstractBranchEndResponseEncoder(resp.AbstractBranchEndResponse) | |||
} | |||
func BranchRegisterRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero32 int32 = 0 | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.BranchRegisterRequest) | |||
if req.Xid != "" { | |||
w.WriteInt16(int16(len(req.Xid))) | |||
w.WriteString(req.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
w.WriteByte(byte(req.BranchType)) | |||
if req.ResourceId != "" { | |||
w.WriteInt16(int16(len(req.ResourceId))) | |||
w.WriteString(req.ResourceId) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.LockKey != "" { | |||
w.WriteInt32(int32(len(req.LockKey))) | |||
w.WriteString(req.LockKey) | |||
} else { | |||
w.WriteInt32(zero32) | |||
} | |||
if req.ApplicationData != nil { | |||
w.WriteUint32(uint32(len(req.ApplicationData))) | |||
w.Write(req.ApplicationData) | |||
} else { | |||
w.WriteInt32(zero32) | |||
} | |||
return b.Bytes() | |||
} | |||
func BranchRegisterResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.BranchRegisterResponse) | |||
data := AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
c := uint64(resp.BranchId) | |||
branchIdBytes := []byte{ | |||
byte(c >> 56), | |||
byte(c >> 48), | |||
byte(c >> 40), | |||
byte(c >> 32), | |||
byte(c >> 24), | |||
byte(c >> 16), | |||
byte(c >> 8), | |||
byte(c), | |||
} | |||
result := append(data, branchIdBytes...) | |||
return result | |||
} | |||
func BranchReportRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero32 int32 = 0 | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.BranchReportRequest) | |||
if req.Xid != "" { | |||
w.WriteInt16(int16(len(req.Xid))) | |||
w.WriteString(req.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
w.WriteInt64(req.BranchId) | |||
w.WriteByte(byte(req.Status)) | |||
if req.ResourceId != "" { | |||
w.WriteInt16(int16(len(req.ResourceId))) | |||
w.WriteString(req.ResourceId) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if req.ApplicationData != nil { | |||
w.WriteUint32(uint32(len(req.ApplicationData))) | |||
w.Write(req.ApplicationData) | |||
} else { | |||
w.WriteInt32(zero32) | |||
} | |||
w.WriteByte(byte(req.BranchType)) | |||
return b.Bytes() | |||
} | |||
func BranchReportResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.BranchReportResponse) | |||
return AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
} | |||
func BranchRollbackRequestEncoder(in interface{}) []byte { | |||
req := in.(message.BranchRollbackRequest) | |||
return AbstractBranchEndRequestEncoder(req.AbstractBranchEndRequest) | |||
} | |||
func BranchRollbackResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.BranchRollbackResponse) | |||
return AbstractBranchEndResponseEncoder(resp.AbstractBranchEndResponse) | |||
} | |||
func GlobalBeginRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.GlobalBeginRequest) | |||
w.WriteInt32(req.Timeout) | |||
if req.TransactionName != "" { | |||
w.WriteInt16(int16(len(req.TransactionName))) | |||
w.WriteString(req.TransactionName) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
return b.Bytes() | |||
} | |||
func GlobalBeginResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.GlobalBeginResponse) | |||
data := AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
if resp.Xid != "" { | |||
w.WriteInt16(int16(len(resp.Xid))) | |||
w.WriteString(resp.Xid) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
if resp.ExtraData != nil { | |||
w.WriteUint16(uint16(len(resp.ExtraData))) | |||
w.Write(resp.ExtraData) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
result := append(data, b.Bytes()...) | |||
return result | |||
} | |||
func GlobalCommitRequestEncoder(in interface{}) []byte { | |||
req := in.(message.GlobalCommitRequest) | |||
return AbstractGlobalEndRequestEncoder(req.AbstractGlobalEndRequest) | |||
} | |||
func GlobalCommitResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.GlobalCommitResponse) | |||
return AbstractGlobalEndResponseEncoder(resp.AbstractGlobalEndResponse) | |||
} | |||
func GlobalLockQueryRequestEncoder(in interface{}) []byte { | |||
return BranchRegisterRequestEncoder(in) | |||
} | |||
func GlobalLockQueryResponseEncoder(in interface{}) []byte { | |||
resp, _ := in.(message.GlobalLockQueryResponse) | |||
data := AbstractTransactionResponseEncoder(resp.AbstractTransactionResponse) | |||
var result []byte | |||
if resp.Lockable { | |||
result = append(data, byte(0), byte(1)) | |||
} else { | |||
result = append(data, byte(0), byte(0)) | |||
} | |||
return result | |||
} | |||
func GlobalReportRequestEncoder(in interface{}) []byte { | |||
req, _ := in.(message.GlobalReportRequest) | |||
data := AbstractGlobalEndRequestEncoder(req.AbstractGlobalEndRequest) | |||
result := append(data, byte(req.GlobalStatus)) | |||
return result | |||
} | |||
func GlobalReportResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.GlobalReportResponse) | |||
return AbstractGlobalEndResponseEncoder(resp.AbstractGlobalEndResponse) | |||
} | |||
func GlobalRollbackRequestEncoder(in interface{}) []byte { | |||
req := in.(message.GlobalRollbackRequest) | |||
return AbstractGlobalEndRequestEncoder(req.AbstractGlobalEndRequest) | |||
} | |||
func GlobalRollbackResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.GlobalRollbackResponse) | |||
return AbstractGlobalEndResponseEncoder(resp.AbstractGlobalEndResponse) | |||
} | |||
func GlobalStatusRequestEncoder(in interface{}) []byte { | |||
req := in.(message.GlobalStatusRequest) | |||
return AbstractGlobalEndRequestEncoder(req.AbstractGlobalEndRequest) | |||
} | |||
func GlobalStatusResponseEncoder(in interface{}) []byte { | |||
resp := in.(message.GlobalStatusResponse) | |||
return AbstractGlobalEndResponseEncoder(resp.AbstractGlobalEndResponse) | |||
} | |||
func UndoLogDeleteRequestEncoder(in interface{}) []byte { | |||
var ( | |||
zero16 int16 = 0 | |||
b bytes.Buffer | |||
) | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
req, _ := in.(message.UndoLogDeleteRequest) | |||
w.WriteByte(byte(req.BranchType)) | |||
if req.ResourceId != "" { | |||
w.WriteInt16(int16(len(req.ResourceId))) | |||
w.WriteString(req.ResourceId) | |||
} else { | |||
w.WriteInt16(zero16) | |||
} | |||
w.WriteInt16(int16(req.SaveDays)) | |||
return b.Bytes() | |||
} |
@@ -6,6 +6,7 @@ import ( | |||
) | |||
type AbstractBranchEndRequest struct { | |||
MessageTypeAware | |||
Xid string | |||
BranchId int64 | |||
BranchType model2.BranchType | |||
@@ -2,10 +2,13 @@ package resource | |||
import ( | |||
"context" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
) | |||
// Resource that can be managed by Resource Manager and involved into global transaction | |||
type Resource interface { | |||
GetResourceGroupId() string | |||
@@ -1,22 +1,58 @@ | |||
package transaction | |||
package seatactx | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
"github.com/seata/seata-go/pkg/rm/tcc/api" | |||
) | |||
type ContextVariable struct { | |||
TxName string | |||
Xid string | |||
Status *GlobalStatus | |||
Role *GlobalTransactionRole | |||
Status *transaction.GlobalStatus | |||
TxRole *transaction.GlobalTransactionRole | |||
BusinessActionContext *api.BusinessActionContext | |||
TxStatus *transaction.GlobalStatus | |||
} | |||
func InitSeataContext(ctx context.Context) context.Context { | |||
return context.WithValue(ctx, common.CONTEXT_VARIABLE, &ContextVariable{}) | |||
} | |||
func GetTxStatus(ctx context.Context) *transaction.GlobalStatus { | |||
variable := ctx.Value(common.CONTEXT_VARIABLE) | |||
if variable == nil { | |||
return nil | |||
} | |||
return variable.(*ContextVariable).TxStatus | |||
} | |||
func SetTxStatus(ctx context.Context, status transaction.GlobalStatus) { | |||
variable := ctx.Value(common.CONTEXT_VARIABLE) | |||
if variable != nil { | |||
variable.(*ContextVariable).TxStatus = &status | |||
} | |||
} | |||
func GetTxName(ctx context.Context) string { | |||
variable := ctx.Value(common.CONTEXT_VARIABLE) | |||
if variable == nil { | |||
return "" | |||
} | |||
return variable.(*ContextVariable).TxName | |||
} | |||
func SetTxName(ctx context.Context, name string) { | |||
variable := ctx.Value(common.TccBusinessActionContext) | |||
if variable != nil { | |||
variable.(*ContextVariable).TxName = name | |||
} | |||
} | |||
func IsSeataContext(ctx context.Context) bool { | |||
return ctx.Value(common.CONTEXT_VARIABLE) != nil | |||
} | |||
@@ -36,18 +72,18 @@ func SetBusinessActionContext(ctx context.Context, businessActionContext *api.Bu | |||
} | |||
} | |||
func GetTransactionRole(ctx context.Context) *GlobalTransactionRole { | |||
func GetTransactionRole(ctx context.Context) *transaction.GlobalTransactionRole { | |||
variable := ctx.Value(common.CONTEXT_VARIABLE) | |||
if variable == nil { | |||
return nil | |||
} | |||
return variable.(*ContextVariable).Role | |||
return variable.(*ContextVariable).TxRole | |||
} | |||
func SetTransactionRole(ctx context.Context, role GlobalTransactionRole) { | |||
func SetTransactionRole(ctx context.Context, role transaction.GlobalTransactionRole) { | |||
variable := ctx.Value(common.CONTEXT_VARIABLE) | |||
if variable != nil { | |||
variable.(*ContextVariable).Role = &role | |||
variable.(*ContextVariable).TxRole = &role | |||
} | |||
} | |||
@@ -0,0 +1,81 @@ | |||
package api | |||
import ( | |||
"context" | |||
"fmt" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/seatactx" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
"github.com/seata/seata-go/pkg/protocol/transaction/manager" | |||
) | |||
type TransactionalExecutor interface { | |||
Execute(ctx context.Context, param interface{}) (interface{}, error) | |||
GetTransactionInfo() transaction.TransactionInfo | |||
} | |||
func Begin(ctx context.Context, name string) context.Context { | |||
if !seatactx.IsSeataContext(ctx) { | |||
ctx = seatactx.InitSeataContext(ctx) | |||
} | |||
seatactx.SetTxName(ctx, name) | |||
if seatactx.GetTransactionRole(ctx) == nil { | |||
seatactx.SetTransactionRole(ctx, transaction.LAUNCHER) | |||
} | |||
var tx *manager.GlobalTransaction | |||
if seatactx.HasXID(ctx) { | |||
tx = &manager.GlobalTransaction{ | |||
Xid: seatactx.GetXID(ctx), | |||
Status: transaction.GlobalStatusBegin, | |||
Role: transaction.PARTICIPANT, | |||
} | |||
seatactx.SetTxStatus(ctx, transaction.GlobalStatusBegin) | |||
} | |||
// todo: Handle the transaction propagation. | |||
if tx == nil { | |||
tx = &manager.GlobalTransaction{ | |||
Xid: seatactx.GetXID(ctx), | |||
Status: transaction.GlobalStatusUnKnown, | |||
Role: transaction.LAUNCHER, | |||
} | |||
seatactx.SetTxStatus(ctx, transaction.GlobalStatusUnKnown) | |||
} | |||
// todo timeout should read from config | |||
err := manager.GetGlobalTransactionManager().Begin(ctx, tx, 50, name) | |||
if err != nil { | |||
panic(fmt.Sprintf("transactionTemplate: begin transaction failed, error %v", err)) | |||
} | |||
return ctx | |||
} | |||
// commit global transaction | |||
func CommitOrRollback(ctx context.Context, err error) error { | |||
tx := &manager.GlobalTransaction{ | |||
Xid: seatactx.GetXID(ctx), | |||
Status: *seatactx.GetTxStatus(ctx), | |||
Role: *seatactx.GetTransactionRole(ctx), | |||
} | |||
var resp error | |||
if err == nil { | |||
resp = manager.GetGlobalTransactionManager().Commit(ctx, tx) | |||
if resp != nil { | |||
log.Infof("transactionTemplate: commit transaction failed, error %v", err) | |||
} | |||
} else { | |||
resp = manager.GetGlobalTransactionManager().Rollback(ctx, tx) | |||
if resp != nil { | |||
log.Infof("transactionTemplate: Rollback transaction failed, error %v", err) | |||
} | |||
} | |||
return resp | |||
} |
@@ -1,96 +0,0 @@ | |||
package executor | |||
import ( | |||
"context" | |||
"github.com/pkg/errors" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
"github.com/seata/seata-go/pkg/protocol/transaction/manager" | |||
"sync" | |||
) | |||
type TransactionalExecutor interface { | |||
Execute(ctx context.Context, param interface{}) (interface{}, error) | |||
GetTransactionInfo() transaction.TransactionInfo | |||
} | |||
var ( | |||
transactionTemplate *TransactionTemplate | |||
onceTransactionTemplate = &sync.Once{} | |||
) | |||
func GetTransactionTemplate() *TransactionTemplate { | |||
if transactionTemplate == nil { | |||
onceTransactionTemplate.Do(func() { | |||
transactionTemplate = &TransactionTemplate{} | |||
}) | |||
} | |||
return transactionTemplate | |||
} | |||
type TransactionTemplate struct { | |||
} | |||
func (t *TransactionTemplate) Execute(ctx context.Context, business TransactionalExecutor, param interface{}) (interface{}, error) { | |||
if !transaction.IsSeataContext(ctx) { | |||
err := errors.New("context should be inited as seata context!") | |||
log.Error(err) | |||
return nil, err | |||
} | |||
if transaction.GetTransactionRole(ctx) == nil { | |||
transaction.SetTransactionRole(ctx, transaction.LAUNCHER) | |||
} | |||
var tx *manager.GlobalTransaction | |||
if transaction.HasXID(ctx) { | |||
tx = &manager.GlobalTransaction{ | |||
Xid: transaction.GetXID(ctx), | |||
Status: transaction.Begin, | |||
Role: transaction.PARTICIPANT, | |||
} | |||
} | |||
// todo: Handle the transaction propagation. | |||
if tx == nil { | |||
tx = &manager.GlobalTransaction{ | |||
Xid: transaction.GetXID(ctx), | |||
Status: transaction.UnKnown, | |||
Role: transaction.LAUNCHER, | |||
} | |||
} | |||
// todo: set current tx config to holder | |||
// begin global transaction | |||
err := t.BeginTransaction(ctx, tx, business.GetTransactionInfo().TimeOut, business.GetTransactionInfo().Name) | |||
if err != nil { | |||
log.Infof("transactionTemplate: begin transaction failed, error %v", err) | |||
return nil, err | |||
} | |||
// do your business | |||
res, err := business.Execute(ctx, param) | |||
if err != nil { | |||
log.Infof("transactionTemplate: execute business failed, error %v", err) | |||
return nil, manager.GetGlobalTransactionManager().Rollback(ctx, tx) | |||
} | |||
// commit global transaction | |||
err = t.CommitTransaction(ctx, tx) | |||
if err != nil { | |||
log.Infof("transactionTemplate: commit transaction failed, error %v", err) | |||
// rollback transaction | |||
return nil, manager.GetGlobalTransactionManager().Rollback(ctx, tx) | |||
} | |||
return res, err | |||
} | |||
func (TransactionTemplate) BeginTransaction(ctx context.Context, tx *manager.GlobalTransaction, timeout int32, name string) error { | |||
return manager.GetGlobalTransactionManager().Begin(ctx, tx, timeout, name) | |||
} | |||
func (TransactionTemplate) CommitTransaction(ctx context.Context, tx *manager.GlobalTransaction) error { | |||
return manager.GetGlobalTransactionManager().Commit(ctx, tx) | |||
} |
@@ -3,13 +3,20 @@ package manager | |||
import ( | |||
"context" | |||
"fmt" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/pkg/errors" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/seatactx" | |||
"github.com/seata/seata-go/pkg/protocol/transaction" | |||
"github.com/seata/seata-go/pkg/remoting/getty" | |||
"github.com/seata/seata-go/pkg/tm/api" | |||
"sync" | |||
) | |||
type GlobalTransaction struct { | |||
@@ -39,7 +46,7 @@ type GlobalTransactionManager struct { | |||
// Begin a new global transaction with given timeout and given name. | |||
func (g *GlobalTransactionManager) Begin(ctx context.Context, gtr *GlobalTransaction, timeout int32, name string) error { | |||
if gtr.Role != transaction.LAUNCHER { | |||
log.Infof("Ignore Begin(): just involved in global transaction %s", gtr.Xid) | |||
log.Infof("Ignore GlobalStatusBegin(): just involved in global transaction %s", gtr.Xid) | |||
return nil | |||
} | |||
if gtr.Xid != "" { | |||
@@ -61,9 +68,9 @@ func (g *GlobalTransactionManager) Begin(ctx context.Context, gtr *GlobalTransac | |||
} | |||
log.Infof("GlobalBeginRequest success, xid %s, res %v", gtr.Xid, res) | |||
gtr.Status = transaction.Begin | |||
gtr.Status = transaction.GlobalStatusBegin | |||
gtr.Xid = res.(message.GlobalBeginResponse).Xid | |||
transaction.SetXID(ctx, res.(message.GlobalBeginResponse).Xid) | |||
seatactx.SetXID(ctx, res.(message.GlobalBeginResponse).Xid) | |||
return nil | |||
} | |||
@@ -98,7 +105,7 @@ func (g *GlobalTransactionManager) Commit(ctx context.Context, gtr *GlobalTransa | |||
if err == nil && res != nil { | |||
gtr.Status = res.(message.GlobalCommitResponse).GlobalStatus | |||
} | |||
transaction.UnbindXid(ctx) | |||
seatactx.UnbindXid(ctx) | |||
log.Infof("GlobalCommitRequest commit success, xid %s", gtr.Xid) | |||
return err | |||
} | |||
@@ -134,7 +141,7 @@ func (g *GlobalTransactionManager) Rollback(ctx context.Context, gtr *GlobalTran | |||
if err == nil && res != nil { | |||
gtr.Status = res.(message.GlobalRollbackResponse).GlobalStatus | |||
} | |||
transaction.UnbindXid(ctx) | |||
seatactx.UnbindXid(ctx) | |||
return err | |||
} | |||
@@ -8,7 +8,7 @@ const ( | |||
) | |||
type TransactionManager interface { | |||
// Begin a new global transaction. | |||
// GlobalStatusBegin a new global transaction. | |||
Begin(applicationId, transactionServiceGroup, name string, timeout int64) (string, error) | |||
// Global commit. | |||
@@ -8,95 +8,95 @@ const ( | |||
* Un known global status. | |||
*/ | |||
// Unknown | |||
UnKnown GlobalStatus = 0 | |||
GlobalStatusUnKnown GlobalStatus = 0 | |||
/** | |||
* The Begin. | |||
* The GlobalStatusBegin. | |||
*/ | |||
// PHASE 1: can accept new branch registering. | |||
Begin GlobalStatus = 1 | |||
GlobalStatusBegin GlobalStatus = 1 | |||
/** | |||
* PHASE 2: Running Status: may be changed any time. | |||
*/ | |||
// Committing. | |||
Committing GlobalStatus = 2 | |||
GlobalStatusCommitting GlobalStatus = 2 | |||
/** | |||
* The Commit retrying. | |||
*/ | |||
// Retrying commit after a recoverable failure. | |||
CommitRetrying GlobalStatus = 3 | |||
GlobalStatusCommitRetrying GlobalStatus = 3 | |||
/** | |||
* Rollbacking global status. | |||
*/ | |||
// Rollbacking | |||
Rollbacking GlobalStatus = 4 | |||
GlobalStatusRollbacking GlobalStatus = 4 | |||
/** | |||
* The Rollback retrying. | |||
*/ | |||
// Retrying rollback after a recoverable failure. | |||
RollbackRetrying GlobalStatus = 5 | |||
GlobalStatusRollbackRetrying GlobalStatus = 5 | |||
/** | |||
* The Timeout rollbacking. | |||
*/ | |||
// Rollbacking since timeout | |||
TimeoutRollbacking GlobalStatus = 6 | |||
GlobalStatusTimeoutRollbacking GlobalStatus = 6 | |||
/** | |||
* The Timeout rollback retrying. | |||
*/ | |||
// Retrying rollback GlobalStatus = since timeout) after a recoverable failure. | |||
TimeoutRollbackRetrying GlobalStatus = 7 | |||
GlobalStatusTimeoutRollbackRetrying GlobalStatus = 7 | |||
/** | |||
* All branches can be async committed. The committing is NOT done yet, but it can be seen as committed for TM/RM | |||
* client. | |||
*/ | |||
AsyncCommitting GlobalStatus = 8 | |||
GlobalStatusAsyncCommitting GlobalStatus = 8 | |||
/** | |||
* PHASE 2: Final Status: will NOT change any more. | |||
*/ | |||
// Finally: global transaction is successfully committed. | |||
Committed GlobalStatus = 9 | |||
GlobalStatusCommitted GlobalStatus = 9 | |||
/** | |||
* The Commit failed. | |||
*/ | |||
// Finally: failed to commit | |||
CommitFailed GlobalStatus = 10 | |||
GlobalStatusCommitFailed GlobalStatus = 10 | |||
/** | |||
* The Rollbacked. | |||
*/ | |||
// Finally: global transaction is successfully rollbacked. | |||
Rollbacked GlobalStatus = 11 | |||
GlobalStatusRollbacked GlobalStatus = 11 | |||
/** | |||
* The Rollback failed. | |||
*/ | |||
// Finally: failed to rollback | |||
RollbackFailed GlobalStatus = 12 | |||
GlobalStatusRollbackFailed GlobalStatus = 12 | |||
/** | |||
* The Timeout rollbacked. | |||
*/ | |||
// Finally: global transaction is successfully rollbacked since timeout. | |||
TimeoutRollbacked GlobalStatus = 13 | |||
GlobalStatusTimeoutRollbacked GlobalStatus = 13 | |||
/** | |||
* The Timeout rollback failed. | |||
*/ | |||
// Finally: failed to rollback since timeout | |||
TimeoutRollbackFailed GlobalStatus = 14 | |||
GlobalStatusTimeoutRollbackFailed GlobalStatus = 14 | |||
/** | |||
* The Finished. | |||
*/ | |||
// Not managed in session MAP any more | |||
Finished GlobalStatus = 15 | |||
GlobalStatusFinished GlobalStatus = 15 | |||
) |
@@ -1,7 +1,6 @@ | |||
package getty | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"sync" | |||
"time" | |||
) | |||
@@ -12,6 +11,7 @@ import ( | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/codec" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
var ( | |||
@@ -44,7 +44,7 @@ func (client *GettyRemotingClient) SendAsyncRequest(msg interface{}) error { | |||
rpcMessage := message.RpcMessage{ | |||
ID: int32(client.idGenerator.Inc()), | |||
Type: msgType, | |||
Codec: codec.SEATA, | |||
Codec: byte(codec.CodeTypeSeata), | |||
Compressor: 0, | |||
Body: msg, | |||
} | |||
@@ -55,7 +55,7 @@ func (client *GettyRemotingClient) SendAsyncResponse(msg interface{}) error { | |||
rpcMessage := message.RpcMessage{ | |||
ID: int32(client.idGenerator.Inc()), | |||
Type: message.GettyRequestType_Response, | |||
Codec: codec.SEATA, | |||
Codec: byte(codec.CodeTypeSeata), | |||
Compressor: 0, | |||
Body: msg, | |||
} | |||
@@ -66,7 +66,7 @@ func (client *GettyRemotingClient) SendSyncRequest(msg interface{}) (interface{} | |||
rpcMessage := message.RpcMessage{ | |||
ID: int32(client.idGenerator.Inc()), | |||
Type: message.GettyRequestType_RequestSync, | |||
Codec: codec.SEATA, | |||
Codec: byte(codec.CodeTypeSeata), | |||
Compressor: 0, | |||
Body: msg, | |||
} | |||
@@ -77,7 +77,7 @@ func (client *GettyRemotingClient) SendSyncRequestWithTimeout(msg interface{}, t | |||
rpcMessage := message.RpcMessage{ | |||
ID: int32(client.idGenerator.Inc()), | |||
Type: message.GettyRequestType_RequestSync, | |||
Codec: codec.SEATA, | |||
Codec: byte(codec.CodeTypeSeata), | |||
Compressor: 0, | |||
Body: msg, | |||
} | |||
@@ -1,8 +1,6 @@ | |||
package getty | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"sync" | |||
"time" | |||
) | |||
@@ -15,6 +13,11 @@ import ( | |||
"github.com/pkg/errors" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
const ( | |||
RPC_REQUEST_TIMEOUT = 30 * time.Second | |||
) | |||
@@ -42,17 +45,17 @@ func GetGettyRemotingInstance() *GettyRemoting { | |||
} | |||
func (client *GettyRemoting) SendSync(msg message.RpcMessage) (interface{}, error) { | |||
ss := clientSessionManager.AcquireGettySession() | |||
ss := sessionManager.AcquireGettySession() | |||
return client.sendAsync(ss, msg, RPC_REQUEST_TIMEOUT) | |||
} | |||
func (client *GettyRemoting) SendSyncWithTimeout(msg message.RpcMessage, timeout time.Duration) (interface{}, error) { | |||
ss := clientSessionManager.AcquireGettySession() | |||
ss := sessionManager.AcquireGettySession() | |||
return client.sendAsync(ss, msg, timeout) | |||
} | |||
func (client *GettyRemoting) SendASync(msg message.RpcMessage) error { | |||
ss := clientSessionManager.AcquireGettySession() | |||
ss := sessionManager.AcquireGettySession() | |||
_, err := client.sendAsync(ss, msg, 0*time.Second) | |||
return err | |||
} | |||
@@ -131,7 +134,7 @@ func (client *GettyRemoting) NotifyRpcMessageResponse(rpcMessage message.RpcMess | |||
// todo add messageFuture.Err | |||
//messageFuture.Err = rpcMessage.Err | |||
messageFuture.Done <- true | |||
//client.msgFutures.Delete(rpcMessage.ID) | |||
//client.msgFutures.Delete(rpcMessage.RequestID) | |||
} else { | |||
log.Infof("msg: {} is not found in msgFutures.", rpcMessage.ID) | |||
} | |||
@@ -2,9 +2,6 @@ package getty | |||
import ( | |||
"context" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/remoting/processor" | |||
"sync" | |||
) | |||
@@ -15,7 +12,10 @@ import ( | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/config" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/remoting/processor" | |||
) | |||
var ( | |||
@@ -47,7 +47,7 @@ func GetGettyClientHandlerInstance() *gettyClientHandler { | |||
} | |||
func (client *gettyClientHandler) OnOpen(session getty.Session) error { | |||
clientSessionManager.RegisterGettySession(session) | |||
sessionManager.RegisterGettySession(session) | |||
go func() { | |||
request := message.RegisterTMRequest{AbstractIdentifyRequest: message.AbstractIdentifyRequest{ | |||
Version: client.conf.SeataVersion, | |||
@@ -58,7 +58,7 @@ func (client *gettyClientHandler) OnOpen(session getty.Session) error { | |||
//client.sendAsyncRequestWithResponse(session, request, RPC_REQUEST_TIMEOUT) | |||
if err != nil { | |||
log.Error("OnOpen error: {%#v}", err.Error()) | |||
clientSessionManager.ReleaseGettySession(session) | |||
sessionManager.ReleaseGettySession(session) | |||
return | |||
} | |||
@@ -70,11 +70,11 @@ func (client *gettyClientHandler) OnOpen(session getty.Session) error { | |||
} | |||
func (client *gettyClientHandler) OnError(session getty.Session, err error) { | |||
clientSessionManager.ReleaseGettySession(session) | |||
sessionManager.ReleaseGettySession(session) | |||
} | |||
func (client *gettyClientHandler) OnClose(session getty.Session) { | |||
clientSessionManager.ReleaseGettySession(session) | |||
sessionManager.ReleaseGettySession(session) | |||
} | |||
func (client *gettyClientHandler) OnMessage(session getty.Session, pkg interface{}) { | |||
@@ -1,21 +1,20 @@ | |||
package getty | |||
import ( | |||
"bytes" | |||
"encoding/binary" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"fmt" | |||
) | |||
import ( | |||
getty "github.com/apache/dubbo-getty" | |||
"github.com/pkg/errors" | |||
"github.com/fagongzi/goetty" | |||
"vimagination.zapto.org/byteio" | |||
"github.com/pkg/errors" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/codec" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
/** | |||
@@ -42,15 +41,14 @@ import ( | |||
* https://github.com/seata/seata/issues/893 | |||
*/ | |||
const ( | |||
SeataV1PackageHeaderReservedLength = 16 | |||
Seatav1HeaderLength = 16 | |||
) | |||
var ( | |||
// RpcPkgHandler | |||
magics = []uint8{0xda, 0xda} | |||
rpcPkgHandler = &RpcPackageHandler{} | |||
) | |||
// TODO 待重构 | |||
var ( | |||
ErrNotEnoughStream = errors.New("packet stream is not enough") | |||
ErrTooLargePackage = errors.New("package length is exceed the getty package's legal maximum length.") | |||
@@ -69,90 +67,50 @@ type SeataV1PackageHeader struct { | |||
MessageType message.GettyRequestType | |||
CodecType byte | |||
CompressType byte | |||
ID uint32 | |||
RequestID uint32 | |||
Meta map[string]string | |||
BodyLength uint32 | |||
Body interface{} | |||
} | |||
func (h *SeataV1PackageHeader) Unmarshal(buf *bytes.Buffer) (int, error) { | |||
bufLen := buf.Len() | |||
if bufLen < SeataV1PackageHeaderReservedLength { | |||
return 0, ErrNotEnoughStream | |||
} | |||
func (p *RpcPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) { | |||
in := goetty.NewByteBuf(len(data)) | |||
in.Write(data) | |||
// magic | |||
if err := binary.Read(buf, binary.BigEndian, &(h.Magic0)); err != nil { | |||
return 0, err | |||
header := SeataV1PackageHeader{} | |||
if in.Readable() < Seatav1HeaderLength { | |||
return nil, 0, fmt.Errorf("invalid package length") | |||
} | |||
if err := binary.Read(buf, binary.BigEndian, &(h.Magic1)); err != nil { | |||
return 0, err | |||
} | |||
if h.Magic0 != message.MAGIC_CODE_BYTES[0] || h.Magic1 != message.MAGIC_CODE_BYTES[1] { | |||
return 0, ErrIllegalMagic | |||
} | |||
// version | |||
if err := binary.Read(buf, binary.BigEndian, &(h.Version)); err != nil { | |||
return 0, err | |||
} | |||
// TODO check version compatible here | |||
// total length | |||
if err := binary.Read(buf, binary.BigEndian, &(h.TotalLength)); err != nil { | |||
return 0, err | |||
} | |||
// head length | |||
if err := binary.Read(buf, binary.BigEndian, &(h.HeadLength)); err != nil { | |||
return 0, err | |||
} | |||
// message type | |||
if err := binary.Read(buf, binary.BigEndian, &(h.MessageType)); err != nil { | |||
return 0, err | |||
} | |||
// codec type | |||
if err := binary.Read(buf, binary.BigEndian, &(h.CodecType)); err != nil { | |||
return 0, err | |||
magic0 := codec.ReadByte(in) | |||
magic1 := codec.ReadByte(in) | |||
if magic0 != magics[0] || magic1 != magics[1] { | |||
return nil, 0, fmt.Errorf("codec decode not found magic offset") | |||
} | |||
// compress type | |||
if err := binary.Read(buf, binary.BigEndian, &(h.CompressType)); err != nil { | |||
return 0, err | |||
} | |||
// id | |||
if err := binary.Read(buf, binary.BigEndian, &(h.ID)); err != nil { | |||
return 0, err | |||
} | |||
// todo meta map | |||
if h.HeadLength > SeataV1PackageHeaderReservedLength { | |||
headMapLength := h.HeadLength - SeataV1PackageHeaderReservedLength | |||
h.Meta = headMapDecode(buf.Bytes()[:headMapLength]) | |||
} | |||
h.BodyLength = h.TotalLength - uint32(h.HeadLength) | |||
return int(h.TotalLength), nil | |||
} | |||
header.Magic0 = magic0 | |||
header.Magic1 = magic1 | |||
header.Version = codec.ReadByte(in) | |||
// length of head and body | |||
header.TotalLength = codec.ReadUInt32(in) | |||
header.HeadLength = codec.ReadUInt16(in) | |||
header.MessageType = message.GettyRequestType(codec.ReadByte(in)) | |||
header.CodecType = codec.ReadByte(in) | |||
header.CompressType = codec.ReadByte(in) | |||
header.RequestID = codec.ReadUInt32(in) | |||
// Read read binary data from to rpc message | |||
func (p *RpcPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) { | |||
var header SeataV1PackageHeader | |||
headMapLength := header.HeadLength - Seatav1HeaderLength | |||
header.Meta = decodeHeapMap(in, headMapLength) | |||
header.BodyLength = header.TotalLength - uint32(header.HeadLength) | |||
buf := bytes.NewBuffer(data) | |||
_, err := header.Unmarshal(buf) | |||
if err != nil { | |||
if err == ErrNotEnoughStream { | |||
// getty case2 | |||
return nil, 0, nil | |||
} | |||
// getty case1 | |||
return nil, 0, err | |||
} | |||
if uint32(len(data)) < header.TotalLength { | |||
// get case3 | |||
return nil, int(header.TotalLength), nil | |||
} | |||
//r := byteio.BigEndianReader{Reader: bytes.NewReader(data)} | |||
rpcMessage := message.RpcMessage{ | |||
Codec: header.CodecType, | |||
ID: int32(header.ID), | |||
ID: int32(header.RequestID), | |||
Compressor: header.CompressType, | |||
Type: header.MessageType, | |||
HeadMap: header.Meta, | |||
@@ -164,8 +122,7 @@ func (p *RpcPackageHandler) Read(ss getty.Session, data []byte) (interface{}, in | |||
rpcMessage.Body = message.HeartBeatMessagePong | |||
} else { | |||
if header.BodyLength > 0 { | |||
//todo compress | |||
msg, _ := codec.MessageDecoder(header.CodecType, data[header.HeadLength:]) | |||
msg := codec.GetCodecManager().Decode(codec.CodecType(header.CodecType), data[header.HeadLength:]) | |||
rpcMessage.Body = msg | |||
} | |||
} | |||
@@ -180,101 +137,91 @@ func (p *RpcPackageHandler) Write(ss getty.Session, pkg interface{}) ([]byte, er | |||
return nil, ErrInvalidPackage | |||
} | |||
fullLength := message.V1HeadLength | |||
totalLength := message.V1HeadLength | |||
headLength := message.V1HeadLength | |||
var result = make([]byte, 0, fullLength) | |||
var b bytes.Buffer | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
result = append(result, message.MAGIC_CODE_BYTES[:2]...) | |||
result = append(result, message.VERSION) | |||
w.WriteByte(byte(msg.Type)) | |||
w.WriteByte(msg.Codec) | |||
w.WriteByte(msg.Compressor) | |||
w.WriteInt32(msg.ID) | |||
var headMapBytes []byte | |||
if msg.HeadMap != nil && len(msg.HeadMap) > 0 { | |||
headMapBytes, headMapLength := headMapEncode(msg.HeadMap) | |||
hb, headMapLength := encodeHeapMap(msg.HeadMap) | |||
headMapBytes = hb | |||
headLength += headMapLength | |||
fullLength += headMapLength | |||
w.Write(headMapBytes) | |||
totalLength += headMapLength | |||
} | |||
var bodyBytes []byte | |||
if msg.Type != message.GettyRequestType_HeartbeatRequest && | |||
msg.Type != message.GettyRequestType_HeartbeatResponse { | |||
bodyBytes := codec.MessageEncoder(msg.Codec, msg.Body) | |||
fullLength += len(bodyBytes) | |||
w.Write(bodyBytes) | |||
} | |||
fullLen := int32(fullLength) | |||
headLen := int16(headLength) | |||
result = append(result, []byte{byte(fullLen >> 24), byte(fullLen >> 16), byte(fullLen >> 8), byte(fullLen)}...) | |||
result = append(result, []byte{byte(headLen >> 8), byte(headLen)}...) | |||
result = append(result, b.Bytes()...) | |||
return result, nil | |||
bodyBytes = codec.GetCodecManager().Encode(codec.CodecType(msg.Codec), msg.Body) | |||
totalLength += len(bodyBytes) | |||
} | |||
buf := goetty.NewByteBuf(0) | |||
buf.WriteByte(message.MAGIC_CODE_BYTES[0]) | |||
buf.WriteByte(message.MAGIC_CODE_BYTES[1]) | |||
buf.WriteByte(message.VERSION) | |||
buf.WriteUInt32(uint32(totalLength)) | |||
buf.WriteUInt16(uint16(headLength)) | |||
buf.WriteByte(byte(msg.Type)) | |||
buf.WriteByte(msg.Codec) | |||
buf.WriteByte(msg.Compressor) | |||
buf.WriteUInt32(uint32(msg.ID)) | |||
buf.Write(headMapBytes) | |||
buf.Write(bodyBytes) | |||
return buf.RawBuf(), nil | |||
} | |||
func headMapDecode(data []byte) map[string]string { | |||
size := len(data) | |||
if size == 0 { | |||
return nil | |||
} | |||
mp := make(map[string]string) | |||
r := byteio.BigEndianReader{Reader: bytes.NewReader(data)} | |||
readLength := 0 | |||
for readLength < size { | |||
var key, value string | |||
lengthK, _, _ := r.ReadUint16() | |||
if lengthK < 0 { | |||
break | |||
} else if lengthK == 0 { | |||
key = "" | |||
func encodeHeapMap(data map[string]string) ([]byte, int) { | |||
buf := goetty.NewByteBuf(0) | |||
for k, v := range data { | |||
if k == "" { | |||
buf.WriteUInt16(uint16(0)) | |||
} else { | |||
key, _, _ = r.ReadString(int(lengthK)) | |||
buf.WriteUInt16(uint16(len(k))) | |||
buf.WriteString(k) | |||
} | |||
lengthV, _, _ := r.ReadUint16() | |||
if lengthV < 0 { | |||
break | |||
} else if lengthV == 0 { | |||
value = "" | |||
if v == "" { | |||
buf.WriteUInt16(uint16(0)) | |||
} else { | |||
value, _, _ = r.ReadString(int(lengthV)) | |||
buf.WriteUInt16(uint16(len(v))) | |||
buf.WriteString(v) | |||
} | |||
mp[key] = value | |||
readLength += int(lengthK + lengthV) | |||
} | |||
return mp | |||
res := buf.RawBuf() | |||
return res, len(res) | |||
} | |||
func headMapEncode(data map[string]string) ([]byte, int) { | |||
var b bytes.Buffer | |||
func decodeHeapMap(in *goetty.ByteBuf, length uint16) map[string]string { | |||
res := make(map[string]string, 0) | |||
if length == 0 { | |||
return res | |||
} | |||
w := byteio.BigEndianWriter{Writer: &b} | |||
for k, v := range data { | |||
if k == "" { | |||
w.WriteUint16(0) | |||
readedLength := uint16(0) | |||
for readedLength < length { | |||
var key, value string | |||
keyLength := codec.ReadUInt16(in) | |||
if keyLength == 0 { | |||
key = "" | |||
} else { | |||
w.WriteUint16(uint16(len(k))) | |||
w.WriteString(k) | |||
keyBytes := make([]byte, keyLength) | |||
keyBytes = codec.Read(in, keyBytes) | |||
key = string(keyBytes) | |||
} | |||
if v == "" { | |||
w.WriteUint16(0) | |||
valueLength := codec.ReadUInt16(in) | |||
if valueLength == 0 { | |||
key = "" | |||
} else { | |||
w.WriteUint16(uint16(len(v))) | |||
w.WriteString(v) | |||
valueBytes := make([]byte, valueLength) | |||
valueBytes = codec.Read(in, valueBytes) | |||
value = string(valueBytes) | |||
} | |||
} | |||
return b.Bytes(), b.Len() | |||
res[key] = value | |||
readedLength += 4 + keyLength + valueLength | |||
fmt.Sprintln("done") | |||
} | |||
return res | |||
} |
@@ -1,8 +1,8 @@ | |||
package getty | |||
import ( | |||
"crypto/tls" | |||
"fmt" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"net" | |||
"sync" | |||
) | |||
@@ -11,9 +11,12 @@ import ( | |||
getty "github.com/apache/dubbo-getty" | |||
gxsync "github.com/dubbogo/gost/sync" | |||
"github.com/pkg/errors" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/config" | |||
) | |||
@@ -62,23 +65,51 @@ func (c *RpcClient) newSession(session getty.Session) error { | |||
var ( | |||
ok bool | |||
tcpConn *net.TCPConn | |||
err error | |||
) | |||
if c.conf.GettyConfig.GettySessionParam.CompressEncoding { | |||
session.SetCompressType(getty.CompressZip) | |||
} | |||
if tcpConn, ok = session.Conn().(*net.TCPConn); !ok { | |||
panic(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection\n", session.Stat(), session.Conn())) | |||
if _, ok = session.Conn().(*tls.Conn); ok { | |||
session.SetName(c.conf.GettyConfig.GettySessionParam.SessionName) | |||
session.SetMaxMsgLen(c.conf.GettyConfig.GettySessionParam.MaxMsgLen) | |||
session.SetPkgHandler(rpcPkgHandler) | |||
session.SetEventListener(GetGettyClientHandlerInstance()) | |||
session.SetReadTimeout(c.conf.GettyConfig.GettySessionParam.TCPReadTimeout) | |||
session.SetWriteTimeout(c.conf.GettyConfig.GettySessionParam.TCPWriteTimeout) | |||
session.SetCronPeriod((int)(c.conf.GettyConfig.GettySessionParam.CronPeriod)) | |||
session.SetWaitTime(c.conf.GettyConfig.GettySessionParam.WaitTimeout) | |||
log.Debugf("server accepts new tls session:%s\n", session.Stat()) | |||
return nil | |||
} | |||
if _, ok = session.Conn().(*net.TCPConn); !ok { | |||
panic(fmt.Sprintf("%s, session.conn{%#v} is not a tcp connection\n", session.Stat(), session.Conn())) | |||
} | |||
tcpConn.SetNoDelay(c.conf.GettyConfig.GettySessionParam.TCPNoDelay) | |||
tcpConn.SetKeepAlive(c.conf.GettyConfig.GettySessionParam.TCPKeepAlive) | |||
if c.conf.GettyConfig.GettySessionParam.TCPKeepAlive { | |||
tcpConn.SetKeepAlivePeriod(c.conf.GettyConfig.GettySessionParam.KeepAlivePeriod) | |||
if _, ok = session.Conn().(*tls.Conn); !ok { | |||
if tcpConn, ok = session.Conn().(*net.TCPConn); !ok { | |||
return errors.New(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection", session.Stat(), session.Conn())) | |||
} | |||
if err = tcpConn.SetNoDelay(c.conf.GettyConfig.GettySessionParam.TCPNoDelay); err != nil { | |||
return err | |||
} | |||
if err = tcpConn.SetKeepAlive(c.conf.GettyConfig.GettySessionParam.TCPKeepAlive); err != nil { | |||
return err | |||
} | |||
if c.conf.GettyConfig.GettySessionParam.TCPKeepAlive { | |||
if err = tcpConn.SetKeepAlivePeriod(c.conf.GettyConfig.GettySessionParam.KeepAlivePeriod); err != nil { | |||
return err | |||
} | |||
} | |||
if err = tcpConn.SetReadBuffer(c.conf.GettyConfig.GettySessionParam.TCPRBufSize); err != nil { | |||
return err | |||
} | |||
if err = tcpConn.SetWriteBuffer(c.conf.GettyConfig.GettySessionParam.TCPWBufSize); err != nil { | |||
return err | |||
} | |||
} | |||
tcpConn.SetReadBuffer(c.conf.GettyConfig.GettySessionParam.TCPRBufSize) | |||
tcpConn.SetWriteBuffer(c.conf.GettyConfig.GettySessionParam.TCPWBufSize) | |||
session.SetName(c.conf.GettyConfig.GettySessionParam.SessionName) | |||
session.SetMaxMsgLen(c.conf.GettyConfig.GettySessionParam.MaxMsgLen) | |||
@@ -86,7 +117,7 @@ func (c *RpcClient) newSession(session getty.Session) error { | |||
session.SetEventListener(GetGettyClientHandlerInstance()) | |||
session.SetReadTimeout(c.conf.GettyConfig.GettySessionParam.TCPReadTimeout) | |||
session.SetWriteTimeout(c.conf.GettyConfig.GettySessionParam.TCPWriteTimeout) | |||
session.SetCronPeriod((int)(c.conf.GettyConfig.HeartbeatPeriod.Nanoseconds() / 1e6)) | |||
session.SetCronPeriod((int)(c.conf.GettyConfig.GettySessionParam.CronPeriod.Nanoseconds() / 1e6)) | |||
session.SetWaitTime(c.conf.GettyConfig.GettySessionParam.WaitTimeout) | |||
log.Debugf("rpc_client new session:%s\n", session.Stat()) | |||
@@ -22,12 +22,12 @@ var ( | |||
sessionSize int32 = 0 | |||
clientSessionManager = &GettyClientSessionManager{} | |||
sessionManager = &GettySessionManager{} | |||
) | |||
type GettyClientSessionManager struct{} | |||
type GettySessionManager struct{} | |||
func (sessionManager *GettyClientSessionManager) AcquireGettySession() getty.Session { | |||
func (sessionManager *GettySessionManager) AcquireGettySession() getty.Session { | |||
// map 遍历是随机的 | |||
var session getty.Session | |||
allSessions.Range(func(key, value interface{}) bool { | |||
@@ -64,7 +64,7 @@ func (sessionManager *GettyClientSessionManager) AcquireGettySession() getty.Ses | |||
return nil | |||
} | |||
func (sessionManager *GettyClientSessionManager) AcquireGettySessionByServerAddress(serverAddress string) getty.Session { | |||
func (sessionManager *GettySessionManager) AcquireGettySessionByServerAddress(serverAddress string) getty.Session { | |||
m, _ := serverSessions.LoadOrStore(serverAddress, &sync.Map{}) | |||
sMap := m.(*sync.Map) | |||
@@ -81,7 +81,7 @@ func (sessionManager *GettyClientSessionManager) AcquireGettySessionByServerAddr | |||
return session | |||
} | |||
func (sessionManager *GettyClientSessionManager) ReleaseGettySession(session getty.Session) { | |||
func (sessionManager *GettySessionManager) ReleaseGettySession(session getty.Session) { | |||
allSessions.Delete(session) | |||
if !session.IsClosed() { | |||
m, _ := serverSessions.LoadOrStore(session.RemoteAddr(), &sync.Map{}) | |||
@@ -92,7 +92,7 @@ func (sessionManager *GettyClientSessionManager) ReleaseGettySession(session get | |||
atomic.AddInt32(&sessionSize, -1) | |||
} | |||
func (sessionManager *GettyClientSessionManager) RegisterGettySession(session getty.Session) { | |||
func (sessionManager *GettySessionManager) RegisterGettySession(session getty.Session) { | |||
allSessions.Store(session, true) | |||
m, _ := serverSessions.LoadOrStore(session.RemoteAddr(), &sync.Map{}) | |||
sMap := m.(*sync.Map) | |||
@@ -2,6 +2,9 @@ package client | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/remoting/getty" | |||
@@ -2,6 +2,9 @@ package client | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
getty2 "github.com/seata/seata-go/pkg/remoting/getty" | |||
@@ -2,6 +2,9 @@ package client | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
getty2 "github.com/seata/seata-go/pkg/remoting/getty" | |||
@@ -2,6 +2,9 @@ package client | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
getty2 "github.com/seata/seata-go/pkg/remoting/getty" | |||
@@ -2,6 +2,9 @@ package processor | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
) | |||
@@ -2,6 +2,9 @@ package handler | |||
import ( | |||
"context" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
@@ -2,9 +2,12 @@ package handler | |||
import ( | |||
"context" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"sync" | |||
) | |||
var ( | |||
@@ -1,12 +1,15 @@ | |||
package remoting | |||
import ( | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/resource" | |||
"github.com/seata/seata-go/pkg/remoting/getty" | |||
"sync" | |||
) | |||
var ( | |||
@@ -3,9 +3,12 @@ package rm | |||
import ( | |||
"context" | |||
"fmt" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/resource" | |||
"sync" | |||
) | |||
var ( | |||
@@ -3,18 +3,18 @@ package tcc | |||
import ( | |||
"context" | |||
"fmt" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/protocol/resource" | |||
"github.com/seata/seata-go/pkg/remoting/getty" | |||
"github.com/seata/seata-go/pkg/rm" | |||
"github.com/seata/seata-go/pkg/rm/common/remoting" | |||
"github.com/seata/seata-go/pkg/rm/tcc/api" | |||
"sync" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/rm" | |||
) | |||
var ( | |||
@@ -111,7 +111,7 @@ func (t *TCCResourceManager) GetManagedResources() sync.Map { | |||
func (t *TCCResourceManager) BranchCommit(ctx context.Context, ranchType branch.BranchType, xid string, branchID int64, resourceID string, applicationData []byte) (branch.BranchStatus, error) { | |||
var tccResource *TCCResource | |||
if resource, ok := t.resourceManagerMap.Load(resourceID); !ok { | |||
err := fmt.Errorf("CC resource is not exist, resourceId: %s", resourceID) | |||
err := fmt.Errorf("TCC resource is not exist, resourceId: %s", resourceID) | |||
return 0, err | |||
} else { | |||
tccResource, _ = resource.(*TCCResource) | |||
@@ -4,20 +4,22 @@ import ( | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"time" | |||
) | |||
import ( | |||
"github.com/pkg/errors" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common" | |||
"github.com/seata/seata-go/pkg/common/log" | |||
"github.com/seata/seata-go/pkg/common/net" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/seatactx" | |||
context2 "github.com/seata/seata-go/pkg/protocol/transaction" | |||
"github.com/seata/seata-go/pkg/protocol/transaction/executor" | |||
"github.com/seata/seata-go/pkg/rm" | |||
api2 "github.com/seata/seata-go/pkg/rm/tcc/api" | |||
"time" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/rm/tcc/remoting" | |||
) | |||
type TCCService interface { | |||
@@ -26,8 +28,8 @@ type TCCService interface { | |||
Rollback(ctx context.Context, businessActionContext api2.BusinessActionContext) error | |||
GetActionName() string | |||
GetRemoteType() remoting.RemoteType | |||
GetServiceType() remoting.ServiceType | |||
//GetRemoteType() remoting.RemoteType | |||
//GetServiceType() remoting.ServiceType | |||
} | |||
type TCCServiceProxy struct { | |||
@@ -57,30 +59,18 @@ func NewTCCServiceProxy(tccService TCCService) TCCService { | |||
} | |||
func (t *TCCServiceProxy) Prepare(ctx context.Context, param interface{}) error { | |||
var err error | |||
if context2.IsSeataContext(ctx) { | |||
// execute transaction | |||
_, err = executor.GetTransactionTemplate().Execute(ctx, t, param) | |||
} else { | |||
log.Warn("context is not inited as seata context, will not execute transaction!") | |||
err = t.TCCService.Prepare(ctx, param) | |||
} | |||
return err | |||
} | |||
// register transaction branch, and then execute business | |||
func (t *TCCServiceProxy) Execute(ctx context.Context, param interface{}) (interface{}, error) { | |||
// register transaction branch | |||
err := t.RegisteBranch(ctx, param) | |||
if err != nil { | |||
return nil, err | |||
if seatactx.HasXID(ctx) { | |||
err := t.RegisteBranch(ctx, param) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
return nil, t.TCCService.Prepare(ctx, param) | |||
return t.TCCService.Prepare(ctx, param) | |||
} | |||
func (t *TCCServiceProxy) RegisteBranch(ctx context.Context, param interface{}) error { | |||
// register transaction branch | |||
if !context2.HasXID(ctx) { | |||
if !seatactx.HasXID(ctx) { | |||
err := errors.New("BranchRegister error, xid should not be nil") | |||
log.Errorf(err.Error()) | |||
return err | |||
@@ -91,7 +81,7 @@ func (t *TCCServiceProxy) RegisteBranch(ctx context.Context, param interface{}) | |||
tccContextStr, _ := json.Marshal(tccContext) | |||
branchId, err := rm.GetResourceManagerInstance().GetResourceManager(branch.BranchTypeTCC).BranchRegister( | |||
ctx, branch.BranchTypeTCC, t.GetActionName(), "", context2.GetXID(ctx), string(tccContextStr), "") | |||
ctx, branch.BranchTypeTCC, t.GetActionName(), "", seatactx.GetXID(ctx), string(tccContextStr), "") | |||
if err != nil { | |||
err = errors.New(fmt.Sprintf("BranchRegister error: %v", err.Error())) | |||
log.Error(err.Error()) | |||
@@ -99,12 +89,12 @@ func (t *TCCServiceProxy) RegisteBranch(ctx context.Context, param interface{}) | |||
} | |||
actionContext := &api2.BusinessActionContext{ | |||
Xid: context2.GetXID(ctx), | |||
Xid: seatactx.GetXID(ctx), | |||
BranchId: string(branchId), | |||
ActionName: t.GetActionName(), | |||
ActionContext: param, | |||
} | |||
context2.SetBusinessActionContext(ctx, actionContext) | |||
seatactx.SetBusinessActionContext(ctx, actionContext) | |||
return nil | |||
} | |||
@@ -1,11 +1,14 @@ | |||
package test | |||
import ( | |||
_ "github.com/seata/seata-go/pkg/imports" | |||
"testing" | |||
"time" | |||
) | |||
import ( | |||
_ "github.com/seata/seata-go/pkg/imports" | |||
) | |||
func TestSendMsgWithResponse(test *testing.T) { | |||
//request := protocol.RegisterRMRequest{ | |||
// ResourceIds: "1111", | |||
@@ -2,13 +2,16 @@ package test | |||
import ( | |||
"context" | |||
"testing" | |||
"time" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/common/log" | |||
_ "github.com/seata/seata-go/pkg/imports" | |||
context2 "github.com/seata/seata-go/pkg/protocol/transaction" | |||
txapi "github.com/seata/seata-go/pkg/protocol/transaction/api" | |||
"github.com/seata/seata-go/pkg/rm/tcc" | |||
"github.com/seata/seata-go/pkg/rm/tcc/api" | |||
"github.com/seata/seata-go/pkg/rm/tcc/remoting" | |||
"testing" | |||
) | |||
type TestTCCServiceBusiness struct { | |||
@@ -33,17 +36,41 @@ func (T TestTCCServiceBusiness) GetActionName() string { | |||
return "TestTCCServiceBusiness" | |||
} | |||
func (T TestTCCServiceBusiness) GetRemoteType() remoting.RemoteType { | |||
return remoting.RemoteTypeLocalService | |||
type TestTCCServiceBusiness2 struct { | |||
} | |||
func (T TestTCCServiceBusiness2) Prepare(ctx context.Context, params interface{}) error { | |||
log.Infof("TestTCCServiceBusiness2 Prepare, param %v", params) | |||
return nil | |||
} | |||
func (T TestTCCServiceBusiness2) Commit(ctx context.Context, businessActionContext api.BusinessActionContext) error { | |||
log.Infof("TestTCCServiceBusiness2 Commit, param %v", businessActionContext) | |||
return nil | |||
} | |||
func (T TestTCCServiceBusiness2) Rollback(ctx context.Context, businessActionContext api.BusinessActionContext) error { | |||
log.Infof("TestTCCServiceBusiness2 Rollback, param %v", businessActionContext) | |||
return nil | |||
} | |||
func (T TestTCCServiceBusiness) GetServiceType() remoting.ServiceType { | |||
return remoting.ServiceTypeProvider | |||
func (T TestTCCServiceBusiness2) GetActionName() string { | |||
return "TestTCCServiceBusiness2" | |||
} | |||
func TestNew(test *testing.T) { | |||
var err error | |||
ctx := txapi.Begin(context.Background(), "TestTCCServiceBusiness") | |||
defer func() { | |||
resp := txapi.CommitOrRollback(ctx, err) | |||
log.Infof("tx result %v", resp) | |||
}() | |||
tccService := tcc.NewTCCServiceProxy(TestTCCServiceBusiness{}) | |||
tccService.Prepare(context2.InitSeataContext(context.Background()), 1) | |||
err = tccService.Prepare(ctx, 1) | |||
tccService2 := tcc.NewTCCServiceProxy(TestTCCServiceBusiness2{}) | |||
err = tccService2.Prepare(ctx, 3) | |||
//time.Sleep(time.Second * 1000) | |||
time.Sleep(time.Second * 1000) | |||
} |
@@ -1,51 +0,0 @@ | |||
package mock | |||
import ( | |||
"context" | |||
"fmt" | |||
"github.com/seata/seata-go/pkg/common/xid" | |||
"github.com/seata/seata-go/pkg/rm/tcc/api" | |||
) | |||
import ( | |||
"github.com/seata/seata-go/pkg/rm/tcc/remoting" | |||
_ "github.com/seata/seata-go/pkg/utils/xid" | |||
) | |||
// 注册RM资源 | |||
func init() { | |||
} | |||
type MockTccService struct { | |||
} | |||
func (*MockTccService) Prepare(ctx context.Context, params interface{}) error { | |||
xid := xid_utils.xid_utils.GetXID(ctx) | |||
fmt.Printf("TccActionOne prepare, xid:" + xid) | |||
return nil | |||
} | |||
func (*MockTccService) Commit(ctx context.Context, businessActionContext api.BusinessActionContext) error { | |||
xid := xid_utils.GetXID(ctx) | |||
fmt.Printf("TccActionOne commit, xid:" + xid) | |||
return nil | |||
} | |||
func (*MockTccService) Rollback(ctx context.Context, businessActionContext api.BusinessActionContext) error { | |||
xid := xid_utils.GetXID(ctx) | |||
fmt.Printf("TccActionOne rollback, xid:" + xid) | |||
return nil | |||
} | |||
func (*MockTccService) GetRemoteType() remoting.RemoteType { | |||
return remoting.RemoteTypeLocalService | |||
} | |||
func (*MockTccService) GetActionName() string { | |||
return "MockTccService" | |||
} | |||
func (*MockTccService) GetServiceType() remoting.ServiceType { | |||
return remoting.ServiceTypeProvider | |||
} |